In [1]:
import os
# Root directory for this notebook's run artifacts; it is later handed to
# Trainer as its checkpoint path.
BASE_DIR = './runs'
import math
# Create the directory up front; exist_ok=True makes re-running the cell harmless.
os.makedirs(BASE_DIR, exist_ok=True)
import random
import torch
import os
from tqdm.notebook import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset
from torch.nn.utils.rnn import pad_sequence
from utils import config
from data_utils import dataset
from model import vit
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from gensim.test.utils import datapath, get_tmpfile
from gensim.models import KeyedVectors
from gensim.scripts.glove2word2vec import glove2word2vec
from model.loss import ContrastiveLoss, ContrastiveSoftMax

#sets random
random_seed=42
random.seed(42)
torch.manual_seed(random_seed)

device="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu" )

match device:
    case "cuda":
        torch.cuda.manual_seed_all(random_seed)
    case "mps":
        torch.mps.manual_seed(random_seed)

import gc
gc.collect()
torch.cuda.empty_cache()

In [2]:
# GloVe: convert the raw GloVe text file into word2vec text format (written to
# a gensim temp file), then load the vectors.
glove_file = 'glove.6B.300d.txt'
word2vec_temp_file = get_tmpfile("glove_word2vec.txt")
glove2word2vec(glove_file, word2vec_temp_file)
# Load once: the file was previously parsed twice in a row, with the second
# result simply overwriting the first.
glove_model = KeyedVectors.load_word2vec_format(word2vec_temp_file)

In [3]:
class Tokenizer:
    """Turns word tokens into rows of pretrained embedding vectors.

    Args:
        embeding_model: Any object supporting ``token in model`` and
            ``model[token]`` (in this notebook a gensim ``KeyedVectors``).
            If it exposes a ``vector_size`` attribute, that width is used for
            the output; otherwise the width defaults to 300.
    """

    def __init__(self, embeding_model):
        self.embeddings = embeding_model
        # Width of one embedding vector; needed to build zero rows for unknown
        # tokens and a correctly shaped result for empty input.
        self.embedding_dim = int(getattr(embeding_model, "vector_size", 300))

    def tokenize(self, sent_tokens):
        """Look up an embedding for every token.

        Args:
            sent_tokens: Iterable of string tokens.

        Returns:
            A float32 CPU tensor of shape ``(len(sent_tokens), embedding_dim)``.
            Tokens missing from the embedding model become all-zero rows. An
            empty token list yields a ``(0, embedding_dim)`` tensor, so callers
            can always concatenate or pad along dim 0.
        """
        sent_tokens = list(sent_tokens)
        # Fill one preallocated array instead of converting a Python list of
        # ndarrays, which is slow and loses the trailing dim when empty.
        vectors = np.zeros((len(sent_tokens), self.embedding_dim), dtype=np.float32)
        for row, token in enumerate(sent_tokens):
            if token in self.embeddings:
                vectors[row] = self.embeddings[token]
        return torch.from_numpy(vectors)

# Module-level tokenizer backed by the loaded GloVe vectors; tokenize_sent_batch
# reads it as a global.
tokenizer = Tokenizer(glove_model)

In [4]:
# Build the three networks from the stage-one YAML configuration.
conf = config.load_config("configs/stage_one.yaml")
image_embeder = vit.ImgEncoder(**conf['model']['ImageEncoder'])
text_embeder = vit.TextConvEncoder(**conf['model']['TextConvEncoder'])
siam_model = vit.SiamEncoder(**conf['model']['SiamEncoder'])

# Cell output: trainable-parameter counts as (image encoder, text encoder, siamese encoder).
tuple(
    sum(param.numel() for param in network.parameters() if param.requires_grad)
    for network in (image_embeder, text_embeder, siam_model)
)

(4513500, 642600, 13087201)

In [5]:
# Group the three networks in one container so a single optimizer can be built
# over all of their parameters.
siam_model_params = torch.nn.ModuleList([image_embeder, text_embeder, siam_model])

In [6]:
# One AdamW optimizer over the parameters of all three networks; its
# hyper-parameters come straight from the config.
optimizer = torch.optim.AdamW(
    siam_model_params.parameters(),
    **conf['optimizer_params'],
)

# Learning-rate schedule: decay by `lr_decay` at each configured milestone.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    gamma=conf['train_settings']['lr_decay'],
    milestones=conf['train_settings']['milestones'],
)

# Per-epoch history; Trainer.train appends to both lists in place.
losses, val_scores = [], []

In [7]:
# Maximum number of tokens kept per sentence; batches are padded to this length.
# NOTE(review): this reads conf['model']['TextEncoder'], while the text model in
# this notebook is built from conf['model']['TextConvEncoder'] — confirm that
# both keys exist in the YAML and agree on seq_len.
SEQ_LEN = conf['model']['TextEncoder']['seq_len']

def tokenize_sent_batch(sents_batch, embedding_dim=300):
    """Embed a batch of sentences and pad each one to ``SEQ_LEN`` tokens.

    Each sentence is split on whitespace, truncated to ``SEQ_LEN`` tokens and
    embedded with the module-level ``tokenizer``.

    Args:
        sents_batch: Iterable of sentence strings.
        embedding_dim: Width of one token embedding (300 for the GloVe vectors
            used in this notebook).

    Returns:
        A pair ``(seqs, lengths)``:
          * ``seqs`` — float tensor of shape
            ``(batch, SEQ_LEN, embedding_dim)``, zero-padded after each
            sentence's last token;
          * ``lengths`` — int64 tensor with the token count of each sentence
            after truncation.
    """
    embedded = []
    length_list = []
    for sent in sents_batch:
        sent_tokens = sent.strip().split()[:SEQ_LEN]
        embedded.append(tokenizer.tokenize(sent_tokens))
        length_list.append(len(sent_tokens))

    # Write every sentence into one zero-filled buffer. This guarantees the
    # SEQ_LEN width without mutating the first sequence, and copes with empty
    # sentences and empty batches.
    seqs = torch.zeros(len(embedded), SEQ_LEN, embedding_dim)
    for row, (vectors, n_tokens) in enumerate(zip(embedded, length_list)):
        if n_tokens:
            seqs[row, :n_tokens] = vectors

    # int64 rather than uint8: uint8 cannot represent lengths above 255.
    return seqs, torch.tensor(length_list, dtype=torch.long)


def pading_sentences_fn(data):
    """Collate function: merge dataset samples into batched tensors.

    Args:
        data: Sequence of ``(mask_tensor, sentence, label)`` samples.

    Returns:
        ``(masks, sentences, lengths, labels)`` where ``masks`` is the stacked
        mask tensors, ``sentences`` and ``lengths`` are the two values returned
        by ``tokenize_sent_batch`` for the raw sentences, and ``labels`` is a
        float tensor of the sample labels.
    """
    sample_masks, raw_sentences, raw_labels = zip(*data)
    mask_batch = torch.stack(sample_masks)
    sentence_batch, sentence_lengths = tokenize_sent_batch(raw_sentences)
    label_batch = torch.FloatTensor(raw_labels)
    return mask_batch, sentence_batch, sentence_lengths, label_batch


# --- Training data -----------------------------------------------------------
train_dataset = dataset.ReferenceDataset(**conf['data']['train'])
train_data_coco_plus = dataset.ReferenceDataset(**conf['data']['train_plus'])
# The "plus" split is built but currently left out of the union; swap the two
# lines below to include it.
# union_train_dataset = ConcatDataset([train_dataset, train_data_coco_plus])
union_train_dataset = ConcatDataset([train_dataset])

# Loader settings shared by the training and validation loaders.
_loader_settings = dict(
    batch_size=conf['train_settings']['batch_size'],
    # pin_memory=True,
    drop_last=True,
    num_workers=6,
    collate_fn=pading_sentences_fn,
)
train_data = DataLoader(union_train_dataset, shuffle=True, **_loader_settings)

# --- Validation data ---------------------------------------------------------
val_dataset = dataset.ReferenceDataset(**conf['data']['val'])
val_dataset_plus = dataset.ReferenceDataset(**conf['data']['val_plus'])
# Same arrangement as for training: the "plus" split is not part of the union.
# union_val_dataset = ConcatDataset([val_dataset, val_dataset_plus])
union_val_dataset = ConcatDataset([val_dataset])
val_data = DataLoader(union_val_dataset, shuffle=False, **_loader_settings)

# (start_epoch, end_epoch) pair; Trainer.train iterates range(start, end).
epoch = (
    conf['train_settings']['start_epoch'],
    conf['train_settings']['epochs'],
)

In [8]:
import json
def save_checkpoint(path_to_checkpoints_folder, checkpoint_name, conf, text_embedder, img_embedder, siam_model, optimizer, scheduler, train_losses, val_scores):
    """Write one checkpoint directory holding config, metrics and weights.

    Creates ``<path_to_checkpoints_folder>/<checkpoint_name>/`` containing:
      * ``config.yaml``  — ``conf`` dumped through ``config.dump_config``;
      * ``metrics.json`` — ``train_losses`` and ``val_scores``;
      * ``model.pt``     — state dicts of the three networks, the optimizer
        and the scheduler.

    Args:
        path_to_checkpoints_folder: Parent directory for all checkpoints.
        checkpoint_name: Name of the sub-directory for this checkpoint.
        conf: Configuration object passed to ``config.dump_config``.
        text_embedder, img_embedder, siam_model, optimizer, scheduler:
            Objects exposing ``state_dict()``.
        train_losses: JSON-serialisable training-loss history.
        val_scores: JSON-serialisable validation-score history.
    """
    checkpoint_dir = os.path.join(path_to_checkpoints_folder, checkpoint_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    config.dump_config(checkpoint_dir, conf, 'config.yaml')

    # Metric history, stored as human-readable JSON.
    metrics = {'train_losses': train_losses, 'val_scores': val_scores}
    with open(os.path.join(checkpoint_dir, 'metrics.json'), 'w') as metrics_file:
        json.dump(metrics, metrics_file, indent=4)

    # Everything needed to resume training goes into a single file.
    stateful_parts = {
        'text_embedder_state_dict': text_embedder,
        'img_embedder_state_dict': img_embedder,
        'siam_model_state_dict': siam_model,
        'optimizer_state_dict': optimizer,
        'scheduler_state_dict': scheduler,
    }
    snapshot = {key: part.state_dict() for key, part in stateful_parts.items()}
    torch.save(snapshot, os.path.join(checkpoint_dir, 'model.pt'))

In [9]:
# Debug aid: with anomaly detection on, backward raises as soon as a function
# returns NaN gradients (this is what reports the 'PowBackward0' failure during
# training in this notebook).
# NOTE(review): PyTorch documents this mode as a debugging tool that slows
# execution — turn it off once the NaN source has been found.
torch.autograd.set_detect_anomaly(True)
# # backward hook with module name
# def get_backward_hook(module_name: str):
    
#     class BackwardHook:
#         name: str
            
#         def __init__(self, name):
#             self.name = name
            
#         def __call__(self, module_name, grad_input, grad_output):
#             for i, g_in in enumerate(grad_input):
#                 if torch.any(torch.isnan(g_in)):
#                     print(f"{module_name}'s {i}th input gradient is nan")
#                     raise Exception
#                 if torch.any(torch.isinf(g_in)):
#                     print(f"{module_name}'s {i}th input gradient is inf")
#                     raise Exception
#             for i, g_out in enumerate(grad_output):
#                 if torch.any(torch.isnan(g_out)):
#                     print(f"{module_name}'s {i}th output gradient is nan")
#                     raise Exception
#                 if torch.any(torch.isinf(g_out)):
#                     print(f"{module_name}'s {i}th output gradient is inf")
#                     raise Exception

#     return BackwardHook(module_name)

# for name, module in text_embeder.named_modules():
#     module.register_full_backward_hook(get_backward_hook(name))

# for name, module in image_embeder.named_modules():
#     module.register_full_backward_hook(get_backward_hook(name))

# for name, module in siam_model.named_modules():
#     module.register_full_backward_hook(get_backward_hook(name))

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f91d7d8d1e0>

In [10]:
class Trainer:
    """Trains and validates the image/text encoders with a contrastive loss.

    Image masks and embedded sentences are passed through their own encoder
    and then through the shared ``siam_model``; the two encodings are compared
    with ``ContrastiveLoss(margin=2.0)``.

    Args:
        text_embedder: Text encoder module.
        img_embedder: Image encoder module.
        siam_model: Shared encoder applied to both modalities. Its forward
            returns a pair; only the first element is used here.
        optimizer: Optimizer over the parameters of the three modules.
        checkpoint_path: Directory that receives checkpoints.
        scheduler: Learning-rate scheduler stepped from ``train``.
        conf: Configuration mapping; ``conf['train_settings']['start_epoch']``
            is updated before each checkpoint is written.
        device: Device string such as ``"cpu"``, ``"cuda"`` or ``"cuda:0"``.
        tb_path: Sub-directory of ``checkpoint_path`` for TensorBoard event
            files, or ``None`` to disable TensorBoard logging.
    """

    def __init__(self, text_embedder, img_embedder, siam_model, optimizer, checkpoint_path, scheduler, conf, device="cpu", tb_path=None):
        self.text_embedder = text_embedder.to(device)
        self.img_embedder = img_embedder.to(device)
        self.siam_model = siam_model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.tb_path = tb_path
        self.checkpoint_path = checkpoint_path
        self.loss_fn = ContrastiveLoss(margin=2.0)
        self.conf = conf
        # torch.compile wrappers, created lazily on the first training epoch.
        self._compiled = None
        if self.tb_path is not None:
            # SummaryWriter's second positional argument is `comment`, which is
            # ignored once log_dir is given, so the sub-directory has to be
            # joined into log_dir explicitly.
            self.writer = SummaryWriter(log_dir=os.path.join(self.checkpoint_path, self.tb_path))

    def _autocast(self):
        """Return an autocast context for this trainer's device *type*.

        ``torch.autocast`` expects a device type ("cuda"), not a device string
        with an index ("cuda:0"), so the type is extracted first.

        NOTE(review): on CUDA autocast defaults to float16 and no GradScaler is
        used in this class; PyTorch recommends pairing the two. Worth
        evaluating given the NaN gradients this notebook hit in training.
        """
        return torch.autocast(device_type=torch.device(self.device).type)

    def _compiled_modules(self):
        """Return compiled (img_embedder, text_embedder, siam_model), building them once."""
        compiled = getattr(self, "_compiled", None)
        if compiled is None:
            compiled = tuple(
                torch.compile(module, mode='max-autotune', fullgraph=True)
                for module in (self.img_embedder, self.text_embedder, self.siam_model)
            )
            self._compiled = compiled
        return compiled

    def train(self, train_data, val_data, epochs, losses, val_scores, checkpoint_step=1, scheduler_step=1):
        """Run the train/validate loop and write checkpoints.

        Args:
            train_data: Iterable of training batches.
            val_data: Iterable of validation batches.
            epochs: ``(start_epoch, end_epoch)``; iterates
                ``range(start_epoch, end_epoch)``.
            losses: Training-loss history; appended to in place.
            val_scores: Validation-MSE history; appended to in place.
            checkpoint_step: Save a checkpoint when ``epoch % checkpoint_step == 0``.
            scheduler_step: Step the scheduler when ``epoch % scheduler_step == 0``.
        """
        tb_enabled = self.tb_path is not None
        first_epoch, last_epoch = epochs[0], epochs[1]

        if tb_enabled:
            # Replay existing history under the SAME tags used for live
            # logging, so a resumed run shows one continuous curve.
            for i, ls in enumerate(losses):
                self.writer.add_scalar("Contrastive_epoche/train", ls, i)
            for i, vl in enumerate(val_scores):
                self.writer.add_scalar("MSE/val", vl, i)

        for epoch in tqdm(range(first_epoch, last_epoch), desc='Epochs'):
            train_loss = self.train_epoch(train_data, epoch)
            val_acc, cosine_acc, negative_mse = self.validate(val_data)
            if tb_enabled:
                self.writer.add_scalar("Cosine/val", cosine_acc, epoch)
                self.writer.add_scalar("Contrastive_epoche/train", train_loss, epoch)
                self.writer.add_scalar("MSE/val", val_acc, epoch)
                self.writer.add_scalar("NEGATIVE_MSE/val", negative_mse, epoch)
            # Print the final epoch number, not the whole (start, end) tuple.
            print(f"Epoch: {epoch}/{last_epoch} - Loss: {train_loss:.4f}")
            losses.append(train_loss)
            val_scores.append(val_acc)
            if epoch % scheduler_step == 0:
                self.scheduler.step()
            if epoch % checkpoint_step == 0:
                # Record the next epoch to run so the saved config can resume.
                self.conf['train_settings']['start_epoch'] = epoch + 1
                save_checkpoint(
                    self.checkpoint_path,
                    f'epoch_{epoch}',
                    self.conf,
                    self.text_embedder,
                    self.img_embedder,
                    self.siam_model,
                    self.optimizer,
                    self.scheduler,
                    losses,
                    val_scores
                )

        if tb_enabled:
            # Make sure buffered events reach disk when training ends.
            self.writer.flush()

    def train_epoch(self, train_data, epoch):
        """Train for one pass over ``train_data``.

        Args:
            train_data: Iterable of ``(masks, sentences, lengths, labels)``
                batches; the lengths are not used.
            epoch: Epoch index, used for the global TensorBoard step.

        Returns:
            Mean contrastive loss over the epoch's steps, as a Python float.
        """
        self.img_embedder.train()
        self.text_embedder.train()
        self.siam_model.train()
        img_embedder, text_embedder, siam_model = self._compiled_modules()

        step_losses = []
        for step, (masks_tensor, sents, _, labels) in enumerate(tqdm(train_data, desc="Training", leave=False)):
            self.optimizer.zero_grad()
            with self._autocast():
                masks_tensor = masks_tensor.to(self.device)
                sents = sents.to(self.device)
                labels = labels.to(self.device)
                img_embeddings = img_embedder(masks_tensor)
                text_embeddings = text_embedder(sents)
                img_encoded, _ = siam_model(img_embeddings)
                text_encoded, _ = siam_model(text_embeddings)
                contrastive_loss = self.loss_fn(img_encoded, text_encoded, labels)
            contrastive_loss.backward()
            self.optimizer.step()

            loss_value = contrastive_loss.item()
            step_losses.append(loss_value)
            if self.tb_path is not None:
                self.writer.add_scalar("Contrastive_loss/train", loss_value, epoch * len(train_data) + step)
        # Plain float keeps the history JSON-serialisable.
        return float(np.mean(step_losses))

    def non_equal_distanse(self, labels, inverse=False):
        """Remap 0/1 labels without modifying the input tensor.

        Args:
            labels: Tensor of labels.
            inverse: If False, swap the labels (0 -> 1, 1 -> 0). If True, map
                0 -> -1 and keep 1 as 1.

        Returns:
            A new tensor with the remapped labels; values other than 0 and 1
            are copied unchanged.
        """
        zero_target, one_target = (-1, 1) if inverse else (1, 0)
        remapped = labels.clone()
        # Both masks come from the ORIGINAL labels, so the first assignment
        # cannot affect the second.
        remapped[labels == 0] = zero_target
        remapped[labels == 1] = one_target
        return remapped

    @torch.no_grad()
    def validate(self, val_data):
        """Evaluate on ``val_data`` with the uncompiled modules.

        For each batch, computes:
          * MSE between image and text encodings, each multiplied by the labels;
          * the same MSE with swapped (0 <-> 1) labels;
          * cosine embedding loss with labels remapped to {-1, 1}.

        Labels are broadcast as ``labels[:, None, None]``, which assumes
        3-dimensional encoder outputs — TODO confirm.

        Returns:
            ``(mean_mse, mean_cosine_loss, mean_negative_mse)`` as Python
            floats, averaged over batches.
        """
        self.img_embedder.eval()
        self.text_embedder.eval()
        self.siam_model.eval()
        mse_acc = []
        negative_mse = []
        coss_acc = []
        mse = torch.nn.functional.mse_loss
        cos_los = torch.nn.functional.cosine_embedding_loss
        for masks_tensor, sents, _, labels in tqdm(val_data, desc="Validating", leave=False):
            with self._autocast():
                masks_tensor = masks_tensor.to(self.device)
                sents = sents.to(self.device)
                labels = labels.to(self.device)
                img_embeddings = self.img_embedder(masks_tensor)
                text_embeddings = self.text_embedder(sents)
                img_encoded, _ = self.siam_model(img_embeddings)
                text_encoded, _ = self.siam_model(text_embeddings)
                mse_loss = mse(img_encoded * labels[:, None, None], text_encoded * labels[:, None, None])
                negative_labels = self.non_equal_distanse(labels)
                negative_mse_loss = mse(img_encoded * negative_labels[:, None, None], text_encoded * negative_labels[:, None, None])
                cos_neg_labels = self.non_equal_distanse(labels, inverse=True)
                cs = cos_los(img_encoded, text_encoded, cos_neg_labels)
            coss_acc.append(cs.item())
            mse_acc.append(mse_loss.item())
            negative_mse.append(negative_mse_loss.item())
        return float(np.mean(mse_acc)), float(np.mean(coss_acc)), float(np.mean(negative_mse))

In [11]:
# Build the trainer: BASE_DIR is passed as the checkpoint path and 'tb' as the
# TensorBoard path argument.
trainer = Trainer(text_embeder, image_embeder, siam_model, optimizer, BASE_DIR, scheduler, conf, device, 'tb')
# `epoch` is the (start_epoch, end_epoch) pair; `losses` and `val_scores` are
# appended to in place as training progresses.
trainer.train(train_data, val_data, epoch, losses, val_scores)

Epochs:   0%|          | 0/300 [00:00<?, ?it/s]

Training:   0%|          | 0/757 [00:00<?, ?it/s]

W0521 16:23:45.054000 140270230462464 torch/_inductor/utils.py:945] [0/0] not enough SMs to use max_autotune_gemm mode


Validating:   0%|          | 0/68 [00:00<?, ?it/s]

Epoch: 0/(0, 300) - Loss: 1.1363


Training:   0%|          | 0/757 [00:00<?, ?it/s]

Validating:   0%|          | 0/68 [00:00<?, ?it/s]

Epoch: 1/(0, 300) - Loss: 1.0098


Training:   0%|          | 0/757 [00:00<?, ?it/s]

Validating:   0%|          | 0/68 [00:00<?, ?it/s]

Epoch: 2/(0, 300) - Loss: 1.0100


Training:   0%|          | 0/757 [00:00<?, ?it/s]

Validating:   0%|          | 0/68 [00:00<?, ?it/s]

Epoch: 3/(0, 300) - Loss: 1.0103


Training:   0%|          | 0/757 [00:00<?, ?it/s]

Validating:   0%|          | 0/68 [00:00<?, ?it/s]

Epoch: 4/(0, 300) - Loss: 1.0154


Training:   0%|          | 0/757 [00:00<?, ?it/s]

RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.

In [None]:
# Post-run cleanup: collect unreachable Python objects, then release PyTorch's
# cached CUDA memory.
import torch
import gc
gc.collect()
torch.cuda.empty_cache()