In [1]:
import numpy as np
import os 
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torch import nn
import torchsummary
from torch.utils.data import DataLoader
from collections import defaultdict
from torchvision.utils import make_grid

import torch
# from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from transformers import BertTokenizer, BertModel, BertForMaskedLM
import logging
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## Datasets

In [3]:
class Dataset():
    def __init__(self, root, phase, transformer=None):
        self.root=root
        self.phase=phase
        self.transformer=transformer
        self.image_list=sorted(os.listdir(root+phase+"/image/"))
        self.des_list=sorted(os.listdir(root+phase+"/description/"))
        self.label_list=sorted(os.listdir(root+phase+"/label/"))
        
    def __getitem__(self, index):
        img, des, label = self.get_data(index)
        return img['image'], des, label
        
    def __len__(self, ):
        return len(self.image_list)

    def get_data(self, index):
        # label
        try:
            label_file_name=self.label_list[index]
            lab_f=open(self.root+self.phase+"/label/"+label_file_name, "r")
            label=lab_f.read()
            if(len(label)>=10):
                label=label[0:10]
            elif(len(label)<10):
                margin=10-len(label)
                padding=" "*margin
                label=label+padding
            label=list(label.lower())

            # description
            des_file_name=self.des_list[index]
            des_f=open(self.root+self.phase+"/description/"+des_file_name, "r")
            des_text=des_f.read()
    #         des=des_text.split(" ")
            des=des_text

            # image
            img_file_name=self.image_list[index]
            image=cv2.imread(self.root+self.phase+"/image/"+img_file_name)
            img=cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            if(self.transformer!=None):
                transformed_img=self.transformer(image=img)
                img=transformed_img
        except:
            print(f"error: image name=>{img_file_name} des name=>{des_file_name} label name=>{label_file_name}")
            img={'image':torch.zeros((3,448,448))}
            des=''
            label=''
        return img, des, label
    
        
    

In [7]:
def build_dataloader(PATH, batch_size=2):
    IMAGE_SIZE = 448
    transformer = A.Compose([
            A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )
    
    dataloaders = {}
#     train_dataset=PET_dataset(part ,neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='train', transformer=transformer, aug=None)
    train_dataset=Dataset(root=PATH, phase="train", transformer=transformer)
    dataloaders["train"] = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


    val_dataset=Dataset(root=PATH, phase="valid", transformer=transformer)
    dataloaders["val"] = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    print(f"trainset:{len(train_dataset)} validset:{len(val_dataset)}")
    return dataloaders

## MODELs
 ![Untitled](../img/nickCLIP_arch.png)

### Image Encoder

In [8]:
class Image_Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        resnet = torchvision.models.resnet34(pretrained = True)
        layers = [m for m in resnet.children()]
        
        self.backbone = nn.Sequential(*layers[:-2]) 
        self.head = nn.Sequential(
                nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, padding=0,bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=128, out_channels=32, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=32, out_channels=4, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(4),
                nn.ReLU(inplace=True),
                nn.Flatten(),
                nn.Linear(in_features=784, out_features=768)
            
            )
    def forward(self, x):
        out = self.backbone(x)
        out = self.head(out) # final output=> (1, 196)
        return out

In [9]:
Image_Enc = Image_Encoder()
Image_Enc.to(device)
torchsummary.summary(Image_Enc, (3,448,448))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
            Conv2d-5         [-1, 64, 112, 112]          36,864
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
            Conv2d-8         [-1, 64, 112, 112]          36,864
       BatchNorm2d-9         [-1, 64, 112, 112]             128
             ReLU-10         [-1, 64, 112, 112]               0
       BasicBlock-11         [-1, 64, 112, 112]               0
           Conv2d-12         [-1, 64, 112, 112]          36,864
      BatchNorm2d-13         [-1, 64, 112, 112]             128
             ReLU-14         [-1, 64, 1

### Text Encoder

In [10]:
class Text_Encoder(nn.Module):
    def __init__(self, device, pretrained='bert-base-uncased'):
        super().__init__()
        self.pretrained=pretrained
        self.device=device
        self.BERT = BertModel.from_pretrained(self.pretrained)
        self.tokenizer = BertTokenizer.from_pretrained(self.pretrained)
    
#     def preprocess(self, text):
#         tokenizer = BertTokenizer.from_pretrained(self.pretrained)
#         marked_text = "[CLS] " + text + " [SEP]"
#         tokenized_text = tokenizer.tokenize(marked_text)
#         indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
#         segments_ids = [1] * len(tokenized_text)
        
#         tokens_tensor = torch.tensor([indexed_tokens]).to(self.device)
#         segments_tensors = torch.tensor([segments_ids]).to(self.device)
        
#         return tokens_tensor, segments_tensors
        
#     def postprocess(self, encoded_layers):
# #         token_embeddings = torch.stack(encoded_layers, dim=0)
# #         token_embeddings = token_embeddings.permute(1,0,2)
#         token_vecs = encoded_layers[11][0]
#         sentence_embedding = torch.mean(token_vecs, dim=0)
        
#         return sentence_embedding
    
    def forward(self,x):
        tokens=self.tokenizer(x, 
                 add_special_tokens=True, 
                 max_length=100, 
                 padding="max_length",
                 truncation=True,
                 return_tensors="pt")
        tokens.to(self.device)
        output = self.BERT(**tokens)
        out = output.last_hidden_state.mean(axis=1)
        return out

In [11]:
Text_Enc=Text_Encoder(device=device)
Text_Enc.to(device)

## Decoder

In [13]:
class lstm_decoder(nn.Module):
    ''' Decodes hidden state output by encoder '''
    
    def __init__(self, input_size, hidden_size, num_layers = 1):

        '''
        : param input_size:     the number of features in the input X
        : param hidden_size:    the number of features in the hidden state h
        : param num_layers:     number of recurrent layers (i.e., 2 means there are
        :                       2 stacked LSTMs)
        '''
        
        super(lstm_decoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size,
                            num_layers = num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, input_size)           

    def forward(self, x_input, encoder_hidden_states):
        
        '''        
        : param x_input:                    should be 2D (batch_size, input_size)
        : param encoder_hidden_states:      hidden states
        : return output, hidden:            output gives all the hidden states in the sequence;
        :                                   hidden gives the hidden state and cell state for the last
        :                                   element in the sequence 
 
        '''
#         print(x_input.shape)
        lstm_out, self.hidden = self.lstm(x_input, encoder_hidden_states)
#         print(f"lstm_out shape:{lstm_out.shape}") #lstm_out.shape->(1,1,768)
        output = self.linear(lstm_out)     
        
        return output, self.hidden

In [14]:
decoder=lstm_decoder(input_size=37, hidden_size=768)
decoder.to(device)
# torchsummary.summary(decoder)

lstm_decoder(
  (lstm): LSTM(37, 768, batch_first=True)
  (linear): Linear(in_features=768, out_features=37, bias=True)
)

![image.png](attachment:image.png)

## nickCLIP

In [15]:
class nickCLIP(nn.Module):
    def __init__(self, image_encoder, text_encoder, decoder, device):
        super().__init__()
        self.image_encoder=image_encoder.to(device)
        self.text_encoder=text_encoder.to(device)
        self.decoder=decoder.to(device)
        
    def forward(self, image, description):
        img_embed=self.image_encoder(image).unsqueeze(0)
        sen_embed=self.text_encoder(description).unsqueeze(0)
        merged_embed=img_embed+sen_embed
        
        batch_size=image.shape[0]
        target_len=10
        initial_c=torch.zeros(1,batch_size,768).to(device)
        start_token = torch.zeros(batch_size,1,37).to(device) # start token??
        
        inputs=start_token
        hidden_state=merged_embed
        cell_state=initial_c
        
        outputs=torch.zeros(batch_size, target_len, 37)
#         for b in range(batch_size):
        for i in range(target_len):
            output,(hidden_state, cell_state)=decoder(inputs,(hidden_state, cell_state))
            inputs=output
#             print(f"{i}-output_shape:{output.shape}")
#             print(f"{outputs[:,i,:].shape}, {output.shape}" )
            outputs[:,i,:]=output.squeeze(1)
        
        
        return outputs
        

In [16]:
Image_Enc = Image_Encoder()
Text_Enc=Text_Encoder(device=device)
Decoder=lstm_decoder(input_size=37, hidden_size=768)
model=nickCLIP(image_encoder=Image_Enc, text_encoder=Text_Enc, decoder=Decoder, device=device)
model.to(device)

nickCLIP(
  (image_encoder): Image_Encoder(
    (backbone): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d

## Train

In [17]:
# img, des, label
def collate_fn(batch):
    image_list = []
    des_list = []
    label_list = []
    
    for a,b,c in batch:
        image_list.append(a)
        des_list.append(b)
        label_list.append(c)

    return torch.stack(image_list, dim=0), des_list, label_list

In [18]:
#One-hot dict
def alp_to_mat(string_list,batch_size):
    one_hot_dict={'a':0,'b':1,'c':2,'d':3,'e':4,'f':5,'g':6,'h':7,'i':8,'j':9,'k':10,'l':11,'m':12,'n':13,'o':14,'p':15,'q':16,'r':17,'s':18,'t':19,'u':20,'v':21,'w':22,'x':23,'y':24,'z':25,'0':26,'1':27,'2':28,'3':29,'4':30,'5':31,'6':32,'7':33,'8':34,'9':35,' ':36}
    alp_to_num_list=[]
    batch_size=len(string_list)
    alp_to_mat_list=torch.zeros((batch_size,len(string_list[0]),37)) #(B,name_len,one_hot_len)
    for j in range(batch_size):
        for i in range(len(string_list[j])):
            char=string_list[j][i]
            if(char in one_hot_dict.keys()):
                pass
            else:
                char=' '
                
            alp_to_num_list.append(one_hot_dict[char])
            mat=torch.zeros(37)
    #         print(len(string_list))
    #         print(mat.shape)
            mat[one_hot_dict[char]]=1
#             print(f"{alp_to_mat_list.shape}-{mat.shape}")
            alp_to_mat_list[j,i,:]=mat
    #         result=torch.Tensor(alp_to_mat_list)
    return alp_to_mat_list
    


# test=[list("heloo     "),list("     fffff"),list("kdkdkkdkdk"),list("abcdefghij")]
# result=alp_to_mat(test,len(test))
# result.shape

## Train

In [20]:
def train_one_epoch(dataloaders, model, criterion, optimizer, device):
    train_loss = defaultdict(float)
    val_loss = defaultdict(float)
    for phase in ["train", "val"]:
        if phase == "train":
            model.train()
        else:
            model.eval()
            
        running_loss = defaultdict(float)
        for index, batch in enumerate(dataloaders[phase]):
            images = batch[0].to(device)
            description = batch[1]
            label = batch[2]
            
            target=alp_to_mat(label,len(label))
            
            with torch.set_grad_enabled(phase == "train"):
                predictions = model(images, description)
                
            target_for_loss=target.view(-1,370).to(device)
            predictions_for_loss=predictions.view(-1,370).to(device)

            loss=criterion(predictions_for_loss, target_for_loss, torch.Tensor(predictions_for_loss.size(0)).cuda().fill_(1.0))
            
            if phase == "train":
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                running_loss["total_loss"] += loss.item()
                
                train_loss["total_loss"] += loss.item()
                
                if (index > 0) and (index % VERBOSE_FREQ) == 0:
                    text = f"<<<iteration:[{index}/{len(dataloaders[phase])}] - "
                    for k, v in running_loss.items():
                        text += f"{k}: {v/VERBOSE_FREQ:.4f}  "
                        running_loss[k] = 0.
                    print(text)
            else:
                val_loss["total_loss"] += loss.item()
        
    for k in train_loss.keys():
        train_loss[k] /= len(dataloaders["train"])
        val_loss[k] /= len(dataloaders["val"])
            
    return train_loss, val_loss

In [24]:
PATH="/workspace/team2/data/filter_50000/"

is_cuda = True

IMAGE_SIZE = 448
BATCH_SIZE = 16
VERBOSE_FREQ = 20
LR=0.001

IMAGE_ENC="RESNET34"
TEXT_ENC="BERT"
DECODER="LSTM"
num_epochs = 100
# DEVICE = torch.device('cuda' if torch.cuda.is_available and is_cuda else 'cpu')

dataloaders = build_dataloader(PATH=PATH, batch_size=BATCH_SIZE)


Image_Enc = Image_Encoder()
Text_Enc=Text_Encoder(device=device)
Decoder=lstm_decoder(input_size=37, hidden_size=768)
model=nickCLIP(image_encoder=Image_Enc, text_encoder=Text_Enc, decoder=Decoder, device=device)
model.to(device)

criterion = torch.nn.CosineEmbeddingLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

trainset:43864 validset:11340




In [18]:
import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="nickclip_baseline",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": LR,
    "batch_size": BATCH_SIZE,
    "image encoder":IMAGE_ENC,
    "test encoder":TEXT_ENC,
    "decoder":DECODER,
    "dataset": PATH,
    "epochs": num_epochs,
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mgomduribo[0m ([33murp[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
best_epoch = 0
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    train_loss, val_loss = train_one_epoch(dataloaders, model, criterion, optimizer, device)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
#     wandb.log({"Train Loss": train_loss['total_loss'],
#                "Val Loss": val_loss['total_loss'],})
    print(f"\nepoch:{epoch+1}/{num_epochs} - Train Loss: {train_loss['total_loss']:.4f}, Val Loss: {val_loss['total_loss']:.4f}\n")
    
    if (epoch+1) % 10 == 0:
        save_model(model.state_dict(), f'model_{epoch+1}.pth', save_dir=f"./trained_model/{IMAGE_ENC}_{TEXT_ENC}_{DECODER}")
wandb.finish()

<<<iteration:[20/2742] - total_loss: 0.8908  
<<<iteration:[40/2742] - total_loss: 0.8179  
<<<iteration:[60/2742] - total_loss: 0.8105  
