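### For Colab

Google Colab setup: clone the repository, install its requirements and build the RefCOCO LMDB dataset.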

In [None]:
!git clone https://github.com/azrails/diplom
%cd diplom
# `-r` is required: without it pip looks for a package named "requirements.txt".
%pip install -q -r requirements.txt

# Download COCO train2014 images (`-p` keeps the cell re-runnable).
!mkdir -p datasets
%cd datasets
!wget http://images.cocodataset.org/zips/train2014.zip
# unzip
!unzip -q train2014.zip -d images/ && rm train2014.zip

# Download RefCOCO annotations.
!wget https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip
# unzip
!unzip -q refcoco.zip && rm refcoco.zip

# Build per-split annotation JSONs and masks, then pack every split into LMDB.
!python ../data_utils/prepare_data.py --data_root . --output_dir . --dataset refcoco --generate_mask
!python ../data_utils/folder_to_lmdb.py -j anns/refcoco/train.json -i images/train2014 -m masks/refcoco -o lmdb/refcoco
!python ../data_utils/folder_to_lmdb.py -j anns/refcoco/val.json -i images/train2014 -m masks/refcoco -o lmdb/refcoco
!python ../data_utils/folder_to_lmdb.py -j anns/refcoco/testA.json -i images/train2014 -m masks/refcoco -o lmdb/refcoco
!python ../data_utils/folder_to_lmdb.py -j anns/refcoco/testB.json -i images/train2014 -m masks/refcoco -o lmdb/refcoco

# clean
!rm -rf refcoco

# Return to the repository root so `utils`, `model`, `data_utils` and
# `configs/...` resolve in the cells below.
%cd ..

In [None]:
import os

# Root directory for checkpoints and TensorBoard logs.
# On Colab this points into Google Drive, which must already be mounted at
# /content/gdrive (e.g. `google.colab.drive.mount('/content/gdrive')`); if it is
# not mounted, makedirs simply creates a plain local directory on the VM.
# Set the STAGE_ONE_DIR environment variable to train somewhere else.
BASE_DIR = os.environ.get('STAGE_ONE_DIR', '/content/gdrive/MyDrive/stage_one')
os.makedirs(BASE_DIR, exist_ok=True)

In [None]:
%load_ext tensorboard
# `{BASE_DIR}` is expanded by IPython, keeping the log directory in sync with
# BASE_DIR instead of repeating the literal path.
%tensorboard --logdir {BASE_DIR}/tb

### Imports

In [6]:
import random
import torch
import os
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from utils import checkpoints, config
from data_utils import tokenizer, dataset
from model import vit, bert
from torch.utils.tensorboard import SummaryWriter

#sets random
random_seed=42
random.seed(42)
torch.manual_seed(random_seed)

device="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu" )

match device:
    case "cuda":
        torch.cuda.manual_seed_all(random_seed)
    case "mps":
        torch.mps.manual_seed(random_seed)

### From Scratch

Run either this cell or the *From Checkpoint* cell below, not both.

In [None]:
conf = config.load_data("configs/stage_one.yaml")
model = vit.StageOneEncoder(**conf['model']['VITEncoder'])
# The optimizer needs the module's Parameter objects. `state_dict()` returns a
# name -> tensor mapping; iterating it yields string keys, which AdamW rejects.
optimizer = torch.optim.AdamW(model.parameters(), **conf['optimizer_params'])
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=conf['train_settings']['milestones'],
    gamma=conf['train_settings']['lr_decay'],
)
# Per-epoch history: appended to during training and stored in every checkpoint.
losses = []
val_scores = []

### From Checkpoint

In [None]:
# TODO: replace with a real checkpoint saved under BASE_DIR. The training loop
# saves checkpoints named f'epoch_{epoch}' (presumably the name to pass here — verify).
CHECKPOINT_NAME = "checkpoint_name"
model, optimizer, scheduler, conf, losses, val_scores = checkpoints.load_checkpoint(BASE_DIR, CHECKPOINT_NAME, device)

### Prepare data

In [4]:

# Text encoder whose sentence embeddings serve as triplet anchors.
text_model = bert.BertEmbedding(conf['model']['text_backbone'])
# Bind the tokenizer *instance* to its own name: assigning it to `tokenizer`
# would shadow the imported `data_utils.tokenizer` module, so re-running this
# cell would fail with an AttributeError.
text_tokenizer = tokenizer.get_bert_tokenizer(conf['model']['text_backbone'])

BATCH_SIZE = conf['train_settings']['batch_size']
NUM_WORKERS = 8
# Pinned host memory only helps host -> CUDA copies; elsewhere it just warns.
PIN_MEMORY = device == "cuda"

train_dataset = dataset.ReferenceDataset(
    **conf['data']['train'],
    tokenizer=text_tokenizer
)
train_data = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=PIN_MEMORY,
    drop_last=True,  # every optimisation step sees a full batch
    num_workers=NUM_WORKERS
)
val_dataset = dataset.ReferenceDataset(
    **conf['data']['val'],
    tokenizer=text_tokenizer
)
val_data = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=PIN_MEMORY,
    drop_last=False,  # score every validation sample, including the last partial batch
    num_workers=NUM_WORKERS
)
# Half-open (start_epoch, end_epoch) range consumed by Trainer.train.
epoch = (conf['train_settings']['start_epoch'], conf['train_settings']['epochs'])


In [None]:
class Trainer:
    """Train a mask encoder so that its embeddings line up with a text encoder's.

    Each training step does three things:
    - embeds the positive mask and the negative mask with ``model``;
    - embeds the sentence with ``text_model``;
    - minimises a triplet margin loss with the sentence embedding as the anchor.

    Validation reports the mean squared error between mask and sentence
    embeddings (lower is better).

    ``text_model`` is treated as a frozen anchor encoder: it is kept in eval
    mode and run under ``torch.no_grad()``. Only the parameters held by
    ``optimizer`` (expected to be ``model``'s) are updated.

    Args:
        model: mask encoder being trained.
        text_model: text encoder, called as ``text_model(sentences, attention_mask)``.
        optimizer: optimizer over ``model``'s parameters.
        checkpoint_path: directory handed to ``checkpoints.save_checkpoint``;
            also the parent directory of the TensorBoard logs.
        scheduler: LR scheduler, stepped every ``scheduler_step`` epochs.
        device: device that both models and every batch are moved to.
        tb_path: TensorBoard sub-directory of ``checkpoint_path``; ``None``
            disables TensorBoard logging.
        conf: config dict saved with each checkpoint. ``None`` falls back to
            the notebook-level ``conf`` global.
    """

    def __init__(self, model, text_model, optimizer, checkpoint_path, scheduler, device="cpu", tb_path=None, conf=None):
        self.model = model.to(device)
        self.text_model = text_model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tb_path = tb_path
        self.conf = conf
        self.loss_fn = torch.nn.TripletMarginLoss()
        self.writer = None
        if tb_path is not None:
            # SummaryWriter's second positional argument is `comment`, not a
            # directory, so build the log directory explicitly.
            self.writer = SummaryWriter(log_dir=os.path.join(checkpoint_path, tb_path))

    def train(self, train_data, val_data, epochs, losses, val_scores, checkpoint_step=5, scheduler_step=1):
        """Run epochs ``epochs[0]`` .. ``epochs[1] - 1``.

        Side effects:
        - appends each epoch's train loss to ``losses`` and its validation MSE
          to ``val_scores``;
        - steps the scheduler every ``scheduler_step`` epochs;
        - logs to TensorBoard when enabled;
        - saves a checkpoint every ``checkpoint_step`` epochs and after the
          final epoch.

        Before each save, ``conf['train_settings']['start_epoch']`` is set to
        the next epoch so that training can be resumed.
        """
        start_epoch, end_epoch = epochs
        # Fall back to the notebook-level `conf` global (original behaviour).
        run_conf = self.conf if self.conf is not None else conf
        for epoch in tqdm(range(start_epoch, end_epoch), desc='Epochs'):
            train_loss = self.train_epoch(train_data)
            val_mse = self.validate(val_data)
            if self.writer is not None:
                self.writer.add_scalar("Loss/train", train_loss, epoch)
                self.writer.add_scalar("MSE/val", val_mse, epoch)
                self.writer.flush()
            losses.append(train_loss)
            val_scores.append(val_mse)
            if epoch % scheduler_step == 0:
                self.scheduler.step()
            is_last_epoch = epoch == end_epoch - 1
            # Also save after the final epoch so its progress is never lost.
            if epoch % checkpoint_step == 0 or is_last_epoch:
                run_conf['train_settings']['start_epoch'] = epoch + 1
                checkpoints.save_checkpoint(
                    self.checkpoint_path,
                    f'epoch_{epoch}',
                    run_conf,
                    self.model,
                    self.optimizer,
                    self.scheduler,
                    losses,
                    val_scores
                    )

    def train_epoch(self, train_data):
        """Train ``model`` for one pass over ``train_data``.

        Batches are unpacked as 5-tuples:
        ``(_, positive_mask, negative_mask, sentences, attention_mask)``.
        The first element is unused (presumably the image — verify against
        ``dataset.ReferenceDataset``).

        Returns:
            Mean triplet loss per processed sample, or NaN if no batch was produced.
        """
        self.model.train()
        self.text_model.eval()  # frozen anchor encoder: no dropout
        total_loss = 0.0
        seen = 0
        for _, mask_batch, negative_batch, sentence_batch, att_mask_batch in tqdm(train_data, desc="Training", leave=False):
            mask_batch = mask_batch.to(self.device)
            negative_batch = negative_batch.to(self.device)
            sentence_batch = sentence_batch.to(self.device)
            att_mask_batch = att_mask_batch.to(self.device)
            # The text encoder is not optimized, so skip building its autograd graph.
            with torch.no_grad():
                anchor_predictions = self.text_model(sentence_batch, att_mask_batch)
            self.optimizer.zero_grad()
            positive_predictions = self.model(mask_batch)
            negative_predictions = self.model(negative_batch)
            step_loss = self.loss_fn(anchor_predictions, positive_predictions, negative_predictions)
            step_loss.backward()
            self.optimizer.step()
            batch_size = len(mask_batch)
            total_loss += step_loss.item() * batch_size
            seen += batch_size
        # Divide by the samples actually processed (drop_last may skip some).
        return total_loss / seen if seen else float("nan")

    @torch.no_grad()
    def validate(self, val_data):
        """Return the mean MSE between mask and sentence embeddings over ``val_data``.

        Uses the same 5-tuple batch layout as ``train_epoch``; the negative
        mask is ignored. Returns NaN if no batch was produced.
        """
        self.model.eval()
        self.text_model.eval()  # deterministic anchors for validation
        mse = torch.nn.MSELoss()
        total_mse = 0.0
        seen = 0
        for _, mask_batch, _, sentence_batch, att_mask_batch in tqdm(val_data, desc="Validating", leave=False):
            mask_batch = mask_batch.to(self.device)
            sentence_batch = sentence_batch.to(self.device)
            att_mask_batch = att_mask_batch.to(self.device)
            positive_predictions = self.model(mask_batch)
            anchor_predictions = self.text_model(sentence_batch, att_mask_batch)
            batch_size = len(mask_batch)
            total_mse += mse(positive_predictions, anchor_predictions).item() * batch_size
            seen += batch_size
        return total_mse / seen if seen else float("nan")

In [None]:
# TensorBoard logs go to the 'tb' sub-directory of the checkpoint root.
trainer = Trainer(
    model=model,
    text_model=text_model,
    optimizer=optimizer,
    checkpoint_path=BASE_DIR,
    scheduler=scheduler,
    device=device,
    tb_path='tb',
)

In [None]:
# Train over the configured epoch range; history lists are extended in place.
trainer.train(
    train_data=train_data,
    val_data=val_data,
    epochs=epoch,
    losses=losses,
    val_scores=val_scores,
)