In [1]:
import numpy as np
import os 
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torch import nn
import torchsummary
from torch.utils.data import DataLoader
from collections import defaultdict
from torchvision.utils import make_grid

import torch
# from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from transformers import BertTokenizer, BertModel, BertForMaskedLM
import logging
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## Datasets

In [3]:
class Dataset():
    def __init__(self, root, phase, transformer=None):
        self.root=root
        self.phase=phase
        self.transformer=transformer
        self.image_list=sorted(os.listdir(root+phase+"/image/"))
        self.des_list=sorted(os.listdir(root+phase+"/description/"))
        self.label_list=sorted(os.listdir(root+phase+"/label/"))
        
    def __getitem__(self, index):
        img, des, label = self.get_data(index)
        return img['image'], des, label
        
    def __len__(self, ):
        return len(self.image_list)

    def get_data(self, index):
        # label
        try:
            label_file_name=self.label_list[index]
            lab_f=open(self.root+self.phase+"/label/"+label_file_name, "r")
            label=lab_f.read()
            if(len(label)>=10):
                label=label[0:10]
            elif(len(label)<10):
                margin=10-len(label)
                padding=" "*margin
                label=label+padding
            label=list(label.lower())

            # description
            des_file_name=self.des_list[index]
            des_f=open(self.root+self.phase+"/description/"+des_file_name, "r")
            des_text=des_f.read()
    #         des=des_text.split(" ")
            des=des_text

            # image
            img_file_name=self.image_list[index]
            image=cv2.imread(self.root+self.phase+"/image/"+img_file_name)
            img=cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            if(self.transformer!=None):
                transformed_img=self.transformer(image=img)
                img=transformed_img
        except:
            print(f"error: image name=>{img_file_name} des name=>{des_file_name} label name=>{label_file_name}")
            img={'image':torch.zeros((3,448,448))}
            des=''
            label=''
        return img, des, label
    
        
    

In [4]:
IMAGE_SIZE=448
transformer = A.Compose([
            A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )

In [5]:
root='/workspace/team2/data/filter_50000/'
train_dataset=Dataset(root=root, phase="train", transformer=transformer)
len(train_dataset)

43864

In [40]:
@interact(index=(0, len(train_dataset)-1))
def show_sample(index):
    img, des, label = train_dataset[index]
    image=img.permute(1,2,0).numpy()
    print(f'desciption: {des}')
    print(f'label: {label}-> label length:{len(label)}')
    print(image.shape)
    plt.imshow(image)
    plt.axis('off')
    plt.show()

interactive(children=(IntSlider(value=21931, description='index', max=43863), Output()), _dom_classes=('widget…

In [6]:
def build_dataloader(PATH, batch_size=2):
    IMAGE_SIZE = 448
    transformer = A.Compose([
            A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )
    
    dataloaders = {}
#     train_dataset=PET_dataset(part ,neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='train', transformer=transformer, aug=None)
    train_dataset=Dataset(root=PATH, phase="train", transformer=transformer)
    dataloaders["train"] = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


    val_dataset=Dataset(root=PATH, phase="valid", transformer=transformer)
    dataloaders["val"] = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    print(f"trainset:{len(train_dataset)} validset:{len(val_dataset)}")
    return dataloaders

## MODELs
 ![Untitled](../img/nickCLIP_arch.png)

### Image Encoder

In [7]:
class Image_Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        resnet = torchvision.models.resnet34(pretrained = True)
        layers = [m for m in resnet.children()]
        
        self.backbone = nn.Sequential(*layers[:-2]) 
        self.head = nn.Sequential(
                nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, padding=0,bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=128, out_channels=32, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=32, out_channels=4, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(4),
                nn.ReLU(inplace=True),
                nn.Flatten(),
                nn.Linear(in_features=784, out_features=768)
            
            )
        
#         self.head = nn.Sequential(
#                 nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, padding=0,bias=False),
#                 nn.BatchNorm2d(256),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1,bias=False),
#                 nn.BatchNorm2d(128),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(in_channels=128, out_channels=32, kernel_size=3, padding=1,bias=False),
#                 nn.BatchNorm2d(32),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding=1,bias=False),
#                 nn.BatchNorm2d(1),
#                 nn.ReLU(inplace=True),
#                 nn.Flatten()
#             )
    def forward(self, x):
        out = self.backbone(x)
        out = self.head(out) # final output=> (1, 196)
        return out

In [8]:
Image_Enc = Image_Encoder()
Image_Enc.to(device)
torchsummary.summary(Image_Enc, (3,448,448))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
            Conv2d-5         [-1, 64, 112, 112]          36,864
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
            Conv2d-8         [-1, 64, 112, 112]          36,864
       BatchNorm2d-9         [-1, 64, 112, 112]             128
             ReLU-10         [-1, 64, 112, 112]               0
       BasicBlock-11         [-1, 64, 112, 112]               0
           Conv2d-12         [-1, 64, 112, 112]          36,864
      BatchNorm2d-13         [-1, 64, 112, 112]             128
             ReLU-14         [-1, 64, 1

### Text Encoder

In [9]:
class Text_Encoder(nn.Module):
    def __init__(self, device, pretrained='bert-base-uncased'):
        super().__init__()
        self.pretrained=pretrained
        self.device=device
        self.BERT = BertModel.from_pretrained(self.pretrained)
        self.tokenizer = BertTokenizer.from_pretrained(self.pretrained)
    
    def preprocess(self, text):
        tokenizer = BertTokenizer.from_pretrained(self.pretrained)
        marked_text = "[CLS] " + text + " [SEP]"
        tokenized_text = tokenizer.tokenize(marked_text)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        segments_ids = [1] * len(tokenized_text)
        
        tokens_tensor = torch.tensor([indexed_tokens]).to(self.device)
        segments_tensors = torch.tensor([segments_ids]).to(self.device)
        
        return tokens_tensor, segments_tensors
        
    def postprocess(self, encoded_layers):
#         token_embeddings = torch.stack(encoded_layers, dim=0)
#         token_embeddings = token_embeddings.permute(1,0,2)
        token_vecs = encoded_layers[11][0]
        sentence_embedding = torch.mean(token_vecs, dim=0)
        
        return sentence_embedding
    
#     def forward(self, x):
#         tokens_tensor, segments_tensors = self.preprocess(x)
        
#         encoded_layers, _ = self.BERT(tokens_tensor, segments_tensors)
        
#         sentence_embedding = self.postprocess(encoded_layers)
        
#         return sentence_embedding
    def forward(self,x):
        tokens=self.tokenizer(x, 
                 add_special_tokens=True, 
                 max_length=100, 
                 padding="max_length",
                 truncation=True,
                 return_tensors="pt")
        tokens.to(self.device)
        output = self.BERT(**tokens)
        out = output.last_hidden_state.mean(axis=1)
        return out

In [10]:
Text_Enc=Text_Encoder(device=device)

In [11]:
Text_Enc.to(device)

Text_Encoder(
  (BERT): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)

In [12]:
@interact(index=(0, len(train_dataset)-1))
def show_sample(index):
    img, des, label = train_dataset[index]
#     image=img['image'].permute(1,2,0).numpy()
    image=img.unsqueeze(0).to(device)
    img_embed=Image_Enc(image)
    sen_embed=Text_Enc(des)
    print(f'desciption: {des}')
    print(f'label: {label}')
    print(f"img shape:{image.shape}")
    print(f'sen_emb shape:{sen_embed.shape}')
    print(f'img_embed shape:{img_embed.shape}')
    image=image.squeeze(0).permute(1,2,0).cpu().numpy()
    plt.imshow(image)
    plt.axis('off')
    plt.show()

interactive(children=(IntSlider(value=21931, description='index', max=43863), Output()), _dom_classes=('widget…

## Decoder

In [12]:
class lstm_decoder(nn.Module):
    ''' Decodes hidden state output by encoder '''
    
    def __init__(self, input_size, hidden_size, num_layers = 1):

        '''
        : param input_size:     the number of features in the input X
        : param hidden_size:    the number of features in the hidden state h
        : param num_layers:     number of recurrent layers (i.e., 2 means there are
        :                       2 stacked LSTMs)
        '''
        
        super(lstm_decoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size,
                            num_layers = num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, input_size)           

    def forward(self, x_input, encoder_hidden_states):
        
        '''        
        : param x_input:                    should be 2D (batch_size, input_size)
        : param encoder_hidden_states:      hidden states
        : return output, hidden:            output gives all the hidden states in the sequence;
        :                                   hidden gives the hidden state and cell state for the last
        :                                   element in the sequence 
 
        '''
#         print(x_input.shape)
        lstm_out, self.hidden = self.lstm(x_input, encoder_hidden_states)
#         print(f"lstm_out shape:{lstm_out.shape}") #lstm_out.shape->(1,1,768)
        output = self.linear(lstm_out)     
        
        return output, self.hidden

In [13]:
decoder=lstm_decoder(input_size=37, hidden_size=768)
decoder.to(device)
# torchsummary.summary(decoder)

lstm_decoder(
  (lstm): LSTM(37, 768, batch_first=True)
  (linear): Linear(in_features=768, out_features=37, bias=True)
)

In [14]:
img, des, label = train_dataset[0]
print(des)
image=img.unsqueeze(0).to(device)
img_embed=Image_Enc(image).unsqueeze(0)
sen_embed=Text_Enc(des).unsqueeze(0).unsqueeze(0)
merged_embed=img_embed+sen_embed

Supermodel turned actress turned horse


In [16]:
image.shape

torch.Size([1, 3, 448, 448])

In [19]:
img_embed.shape

torch.Size([1, 1, 768])

In [20]:
sen_embed.shape

torch.Size([1, 1, 768])

In [21]:
merged_embed.shape

torch.Size([1, 1, 768])

In [22]:
initial_c=torch.zeros(1,1,768).to(device)
inputs = torch.zeros(1,1,37).to(device)

# input->(1,27), merged_embed->(1,1,768), initial_c->(1,1,768)
output,(hidden_state, cell_state)=decoder(inputs,(merged_embed, initial_c))

In [23]:
print(output.shape) #[1(B), 1, 27]
print(hidden_state.shape) #[1, 1(B), 768]
print(cell_state.shape) #[1, 1(B), 768]

torch.Size([1, 1, 37])
torch.Size([1, 1, 768])
torch.Size([1, 1, 768])


In [17]:
PATH="/workspace/team2/data/filter_50000/"
BATCH_SIZE = 16
dataloaders = build_dataloader(PATH=PATH, batch_size=BATCH_SIZE)
for index, batch in enumerate(dataloaders['train']):
    images = batch[0]
    description = batch[1]
    label = batch[2]
    print(f"{index}-{images.shape} {len(description)} {len(label)}")
    
    img_embed=Image_Enc(images.to(device))
    sen_embed=Text_Enc(description)
    concat_embed=img_embed+sen_embed
    
    print(f"{img_embed.shape}, {sen_embed.shape}. {concat_embed.shape}")
    initial_c=torch.zeros(1,BATCH_SIZE,768).to(device)
    start_token = torch.zeros(BATCH_SIZE,1,37).to(device) # start token??
    target_len=10
    outputs=torch.zeros(BATCH_SIZE, target_len, 37)

    inputs=start_token
    hidden_state=concat_embed.unsqueeze(0)
    cell_state=initial_c
    print(f"initial inputs:{inputs.shape}, hidden_state:{hidden_state.shape}, cell_state:{cell_state.shape}")
    
    for i in range(target_len):
            output,(hidden_state, cell_state)=decoder(inputs,(hidden_state, cell_state))
            inputs=output
#             print(f"{i}-output_shape:{output.shape}")
#             print(f"{outputs[:,i,:].shape}, {output.shape}" )
            outputs[:,i,:]=output.squeeze(1)
    print(outputs.shape)
    
    
    target=alp_to_mat(label,len(label))
#     output=model(image, description)
    
    target_for_loss=target.view(-1,370).to(device)
    output_for_loss=outputs.view(-1,370).to(device)
#     print(f"out: {outputs}\n------------------------------")
    print(f"out changed shape: {target_for_loss}")
    
#     loss=criterion(output_for_loss, target_for_loss, torch.Tensor(output_for_loss.size(0)).cuda().fill_(1.0))
#     print(f"output shape:{output.shape} <=> label shape:{len(label)}, target shape:{target.shape}")
#     print(f"loss: {loss}")
#     print("\n")

trainset:43864 validset:11340
0-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-3.9634e-02,  1.9520e-02,  3.9103e-02,  ..., -1.1104e-02,
           1.1740e-02,  9.3796e-02],
         [-3.0007e-02, -1.6612e-03,  2.5583e-02,  ...,  1.9044e-02,
           4.6704e-03,  5.3334e-02],
         [-2.5383e-02, -1.1842e-02,  2.7639e-02,  ...,  2.8665e-02,
           8.9096e-03,  3.5920e-02],
         ...,
         [-2.4674e-02, -2.1827e-02,  3.5728e-02,  ...,  3.1756e-02,
           1.5382e-02,  1.0883e-02],
         [-2.4801e-02, -2.1698e-02,  3.5886e-02,  ...,  3.1401e-02,
           1.5600e-02,  1.0008e-02],
         [-2.4855e-02, -2.1561e-02,  3.5955e-02,  ...,  3.1135e-02,
           1.5717e-02,  9.4895e-03]],

        [[-5.1369e-02,  8.8383e-03, -4.6239e-03,  ...,  8.5014e-03,
     

2-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-0.0206, -0.0045, -0.0035,  ...,  0.0069,  0.0123, -0.0292],
         [-0.0205, -0.0110,  0.0183,  ...,  0.0185,  0.0051, -0.0076],
         [-0.0210, -0.0189,  0.0296,  ...,  0.0243,  0.0055,  0.0025],
         ...,
         [-0.0243, -0.0230,  0.0366,  ...,  0.0308,  0.0150,  0.0090],
         [-0.0245, -0.0225,  0.0364,  ...,  0.0307,  0.0155,  0.0090],
         [-0.0246, -0.0221,  0.0362,  ...,  0.0307,  0.0157,  0.0089]],

        [[-0.0660, -0.0238,  0.0591,  ..., -0.0402,  0.0224, -0.0106],
         [-0.0542, -0.0288,  0.0467,  ...,  0.0037,  0.0174,  0.0093],
         [-0.0444, -0.0269,  0.0428,  ...,  0.0212,  0.0132,  0.0136],
         ...,
         [-0.0268, -0.0216,  0.0360,  ...,  0.0319,  0.0156,  0.0098],
         

5-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-0.0090,  0.0038,  0.0470,  ...,  0.0091,  0.0461, -0.0220],
         [-0.0213, -0.0253,  0.0400,  ...,  0.0188,  0.0291, -0.0054],
         [-0.0212, -0.0298,  0.0350,  ...,  0.0286,  0.0200,  0.0043],
         ...,
         [-0.0249, -0.0229,  0.0352,  ...,  0.0323,  0.0157,  0.0091],
         [-0.0249, -0.0222,  0.0355,  ...,  0.0318,  0.0158,  0.0090],
         [-0.0249, -0.0218,  0.0357,  ...,  0.0314,  0.0158,  0.0089]],

        [[-0.0870,  0.0176, -0.0320,  ...,  0.0006,  0.0283,  0.0026],
         [-0.0601, -0.0117, -0.0252,  ...,  0.0250,  0.0250,  0.0053],
         [-0.0455, -0.0204, -0.0023,  ...,  0.0304,  0.0174,  0.0124],
         ...,
         [-0.0263, -0.0218,  0.0345,  ...,  0.0329,  0.0156,  0.0099],
         

8-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-0.0698, -0.0233,  0.0073,  ..., -0.0134,  0.0167,  0.0384],
         [-0.0513, -0.0212,  0.0141,  ...,  0.0163,  0.0062,  0.0231],
         [-0.0396, -0.0225,  0.0211,  ...,  0.0284,  0.0058,  0.0166],
         ...,
         [-0.0256, -0.0215,  0.0350,  ...,  0.0317,  0.0146,  0.0095],
         [-0.0253, -0.0213,  0.0355,  ...,  0.0313,  0.0151,  0.0092],
         [-0.0251, -0.0213,  0.0357,  ...,  0.0310,  0.0154,  0.0090]],

        [[-0.0983, -0.0321,  0.0134,  ...,  0.0204, -0.0226, -0.0616],
         [-0.0560, -0.0355,  0.0219,  ...,  0.0325, -0.0003, -0.0300],
         [-0.0396, -0.0324,  0.0288,  ...,  0.0361,  0.0106, -0.0114],
         ...,
         [-0.0259, -0.0229,  0.0355,  ...,  0.0323,  0.0157,  0.0075],
         

11-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-0.0294,  0.0538,  0.0021,  ...,  0.0143,  0.0201,  0.0281],
         [-0.0365,  0.0190,  0.0076,  ...,  0.0206,  0.0199,  0.0191],
         [-0.0353, -0.0005,  0.0176,  ...,  0.0235,  0.0171,  0.0175],
         ...,
         [-0.0259, -0.0204,  0.0349,  ...,  0.0303,  0.0156,  0.0102],
         [-0.0254, -0.0208,  0.0354,  ...,  0.0305,  0.0157,  0.0097],
         [-0.0252, -0.0210,  0.0357,  ...,  0.0306,  0.0157,  0.0093]],

        [[-0.0713, -0.0401,  0.0606,  ..., -0.0425,  0.0580,  0.0691],
         [-0.0517, -0.0353,  0.0466,  ..., -0.0014,  0.0414,  0.0447],
         [-0.0391, -0.0308,  0.0438,  ...,  0.0168,  0.0313,  0.0317],
         ...,
         [-0.0256, -0.0225,  0.0370,  ...,  0.0303,  0.0168,  0.0106],
        

14-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-0.0194, -0.0026,  0.0533,  ...,  0.0027, -0.0288,  0.0517],
         [-0.0258, -0.0158,  0.0372,  ...,  0.0256, -0.0105,  0.0334],
         [-0.0270, -0.0198,  0.0346,  ...,  0.0326,  0.0006,  0.0215],
         ...,
         [-0.0253, -0.0214,  0.0360,  ...,  0.0319,  0.0148,  0.0093],
         [-0.0251, -0.0213,  0.0360,  ...,  0.0315,  0.0152,  0.0090],
         [-0.0250, -0.0213,  0.0360,  ...,  0.0312,  0.0155,  0.0089]],

        [[ 0.0078,  0.0028,  0.0060,  ..., -0.0500,  0.0062,  0.0356],
         [-0.0165, -0.0188,  0.0092,  ..., -0.0068,  0.0109,  0.0224],
         [-0.0201, -0.0258,  0.0172,  ...,  0.0131,  0.0116,  0.0183],
         ...,
         [-0.0249, -0.0228,  0.0348,  ...,  0.0305,  0.0157,  0.0099],
        

17-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-0.0110, -0.1071,  0.0352,  ..., -0.0218, -0.0109,  0.0029],
         [-0.0226, -0.0724,  0.0375,  ...,  0.0109,  0.0001,  0.0077],
         [-0.0250, -0.0502,  0.0353,  ...,  0.0255,  0.0064,  0.0102],
         ...,
         [-0.0254, -0.0232,  0.0356,  ...,  0.0317,  0.0159,  0.0093],
         [-0.0252, -0.0224,  0.0357,  ...,  0.0314,  0.0160,  0.0091],
         [-0.0250, -0.0219,  0.0359,  ...,  0.0311,  0.0160,  0.0090]],

        [[ 0.0407,  0.0077,  0.0147,  ..., -0.0186,  0.0409,  0.0523],
         [ 0.0207, -0.0120,  0.0148,  ..., -0.0004,  0.0110,  0.0369],
         [ 0.0008, -0.0176,  0.0215,  ...,  0.0156,  0.0070,  0.0311],
         ...,
         [-0.0248, -0.0209,  0.0344,  ...,  0.0297,  0.0147,  0.0108],
        

20-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[ 0.0117,  0.0164,  0.0803,  ..., -0.0122,  0.0247,  0.0692],
         [-0.0122, -0.0093,  0.0533,  ...,  0.0037,  0.0145,  0.0429],
         [-0.0203, -0.0177,  0.0440,  ...,  0.0153,  0.0128,  0.0311],
         ...,
         [-0.0257, -0.0215,  0.0363,  ...,  0.0305,  0.0155,  0.0105],
         [-0.0255, -0.0214,  0.0361,  ...,  0.0306,  0.0157,  0.0098],
         [-0.0253, -0.0213,  0.0361,  ...,  0.0306,  0.0158,  0.0093]],

        [[-0.0774,  0.0107, -0.0025,  ...,  0.0085, -0.0412,  0.0393],
         [-0.0510, -0.0027,  0.0067,  ...,  0.0204, -0.0073,  0.0272],
         [-0.0370, -0.0106,  0.0189,  ...,  0.0247,  0.0039,  0.0213],
         ...,
         [-0.0251, -0.0213,  0.0361,  ...,  0.0311,  0.0155,  0.0097],
        

torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-0.0155, -0.0493,  0.0293,  ..., -0.0167,  0.0347,  0.0042],
         [-0.0173, -0.0459,  0.0239,  ...,  0.0077,  0.0245,  0.0104],
         [-0.0199, -0.0371,  0.0305,  ...,  0.0199,  0.0162,  0.0151],
         ...,
         [-0.0248, -0.0227,  0.0362,  ...,  0.0310,  0.0148,  0.0103],
         [-0.0249, -0.0221,  0.0362,  ...,  0.0309,  0.0151,  0.0097],
         [-0.0249, -0.0217,  0.0361,  ...,  0.0308,  0.0154,  0.0093]],

        [[-0.0926, -0.0309,  0.0251,  ...,  0.0555,  0.0493, -0.0121],
         [-0.0642, -0.0230,  0.0318,  ...,  0.0460,  0.0348, -0.0002],
         [-0.0486, -0.0225,  0.0342,  ...,  0.0382,  0.0271,  0.0061],
         ...,
         [-0.0267, -0.0219,  0.0363,  ...,  0.0318,  0.0175,  0.0089],
         [-0.0260, -0.0217,  0.0363,  ...,  0.0

torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-0.0120,  0.0013,  0.0261,  ..., -0.0077,  0.0156,  0.0848],
         [-0.0233, -0.0133,  0.0182,  ...,  0.0199,  0.0130,  0.0504],
         [-0.0260, -0.0192,  0.0240,  ...,  0.0263,  0.0125,  0.0314],
         ...,
         [-0.0256, -0.0216,  0.0360,  ...,  0.0303,  0.0156,  0.0102],
         [-0.0253, -0.0215,  0.0361,  ...,  0.0303,  0.0158,  0.0096],
         [-0.0251, -0.0214,  0.0362,  ...,  0.0304,  0.0158,  0.0092]],

        [[-0.0448, -0.0522, -0.0072,  ..., -0.0450,  0.0204,  0.0210],
         [-0.0383, -0.0412,  0.0155,  ..., -0.0047,  0.0163,  0.0152],
         [-0.0331, -0.0342,  0.0264,  ...,  0.0135,  0.0142,  0.0140],
         ...,
         [-0.0253, -0.0221,  0.0361,  ...,  0.0305,  0.0155,  0.0097],
         [-0.0251, -0.0217,  0.0361,  ...,  0.0

29-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-4.6008e-02, -2.1944e-02,  4.0873e-02,  ..., -2.3467e-02,
           4.8328e-02, -8.9022e-03],
         [-3.3782e-02, -2.7676e-02,  3.7890e-02,  ...,  5.7024e-03,
           3.4707e-02,  1.5498e-02],
         [-2.7372e-02, -2.8059e-02,  3.4839e-02,  ...,  2.0351e-02,
           2.7233e-02,  1.9897e-02],
         ...,
         [-2.4974e-02, -2.2345e-02,  3.5520e-02,  ...,  3.0808e-02,
           1.6900e-02,  9.7383e-03],
         [-2.4944e-02, -2.1887e-02,  3.5753e-02,  ...,  3.0799e-02,
           1.6472e-02,  9.1919e-03],
         [-2.4898e-02, -2.1600e-02,  3.5892e-02,  ...,  3.0763e-02,
           1.6201e-02,  8.9133e-03]],

        [[-5.6132e-02,  7.2714e-03,  2.5638e-02,  ..., -2.9602e-02,
           2.1966e-02,  1.6580e-02

torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-0.0402, -0.0664,  0.0581,  ...,  0.0384,  0.0373,  0.0078],
         [-0.0317, -0.0525,  0.0466,  ...,  0.0340,  0.0275,  0.0193],
         [-0.0297, -0.0431,  0.0433,  ...,  0.0342,  0.0236,  0.0229],
         ...,
         [-0.0259, -0.0232,  0.0368,  ...,  0.0314,  0.0160,  0.0102],
         [-0.0255, -0.0223,  0.0365,  ...,  0.0311,  0.0159,  0.0095],
         [-0.0253, -0.0218,  0.0363,  ...,  0.0309,  0.0158,  0.0091]],

        [[-0.1020,  0.0197,  0.0158,  ..., -0.0463,  0.0284, -0.0037],
         [-0.0718, -0.0040,  0.0109,  ..., -0.0010,  0.0251, -0.0078],
         [-0.0501, -0.0152,  0.0184,  ...,  0.0183,  0.0222, -0.0010],
         ...,
         [-0.0261, -0.0218,  0.0347,  ...,  0.0311,  0.0165,  0.0086],
         [-0.0255, -0.0216,  0.0353,  ...,  0.0

34-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[-7.5338e-02,  7.5365e-04, -3.6323e-02,  ...,  3.6067e-02,
           4.4927e-02, -1.2034e-02],
         [-5.7321e-02, -1.2718e-02, -7.2274e-03,  ...,  3.2428e-02,
           2.3120e-02,  1.3782e-03],
         [-4.3358e-02, -1.9086e-02,  1.2699e-02,  ...,  3.2775e-02,
           1.6289e-02,  9.7426e-03],
         ...,
         [-2.6055e-02, -2.1743e-02,  3.5155e-02,  ...,  3.1261e-02,
           1.5453e-02,  9.7034e-03],
         [-2.5560e-02, -2.1565e-02,  3.5625e-02,  ...,  3.1082e-02,
           1.5633e-02,  9.2889e-03],
         [-2.5256e-02, -2.1433e-02,  3.5857e-02,  ...,  3.0943e-02,
           1.5729e-02,  9.0363e-03]],

        [[-7.1348e-02, -8.2351e-03,  6.7642e-02,  ..., -2.4184e-02,
           5.2693e-02,  2.1775e-02

37-torch.Size([16, 3, 448, 448]) 16 16
torch.Size([16, 768]), torch.Size([16, 768]). torch.Size([16, 768])
initial inputs:torch.Size([16, 1, 37]), hidden_state:torch.Size([1, 16, 768]), cell_state:torch.Size([1, 16, 768])
torch.Size([16, 10, 37])
out: tensor([[[ 0.0137, -0.0173,  0.0567,  ...,  0.0589, -0.0179,  0.0488],
         [-0.0036, -0.0231,  0.0352,  ...,  0.0441, -0.0053,  0.0360],
         [-0.0126, -0.0241,  0.0334,  ...,  0.0359,  0.0034,  0.0272],
         ...,
         [-0.0245, -0.0218,  0.0365,  ...,  0.0314,  0.0153,  0.0103],
         [-0.0247, -0.0216,  0.0364,  ...,  0.0312,  0.0156,  0.0096],
         [-0.0248, -0.0214,  0.0363,  ...,  0.0310,  0.0157,  0.0093]],

        [[-0.0650, -0.0442, -0.0009,  ..., -0.0229,  0.0135,  0.0660],
         [-0.0524, -0.0291,  0.0190,  ...,  0.0035,  0.0185,  0.0396],
         [-0.0422, -0.0243,  0.0313,  ...,  0.0168,  0.0189,  0.0256],
         ...,
         [-0.0264, -0.0212,  0.0369,  ...,  0.0307,  0.0165,  0.0093],
        

40-torch.Size([16, 3, 448, 448]) 16 16


KeyboardInterrupt: 

![image.png](attachment:image.png)

In [44]:
test=torch.zeros(8, 10, 37)
test[0,:,:].shape

torch.Size([10, 37])

## nickCLIP

In [54]:
class nickCLIP(nn.Module):
    def __init__(self, image_encoder, text_encoder, decoder, device):
        super().__init__()
        self.image_encoder=image_encoder.to(device)
        self.text_encoder=text_encoder.to(device)
        self.decoder=decoder.to(device)
        
    def forward(self, image, description):
        img_embed=self.image_encoder(image).unsqueeze(0)
        sen_embed=self.text_encoder(description).unsqueeze(0)
        merged_embed=img_embed+sen_embed
        
        batch_size=image.shape[0]
        target_len=10
        initial_c=torch.zeros(1,batch_size,768).to(device)
        start_token = torch.zeros(batch_size,1,37).to(device) # start token??
        
        inputs=start_token
        hidden_state=merged_embed
        cell_state=initial_c
        
        outputs=torch.zeros(batch_size, target_len, 37)
#         for b in range(batch_size):
        for i in range(target_len):
            output,(hidden_state, cell_state)=decoder(inputs,(hidden_state, cell_state))
            inputs=output
#             print(f"{i}-output_shape:{output.shape}")
#             print(f"{outputs[:,i,:].shape}, {output.shape}" )
            outputs[:,i,:]=output.squeeze(1)
        
        
        return outputs
        

In [55]:
Image_Enc = Image_Encoder()
Text_Enc=Text_Encoder(device=device)
Decoder=lstm_decoder(input_size=37, hidden_size=768)
model=nickCLIP(image_encoder=Image_Enc, text_encoder=Text_Enc, decoder=Decoder, device=device)
model.to(device)

nickCLIP(
  (image_encoder): Image_Encoder(
    (backbone): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d

## Train

In [16]:
# img, des, label
def collate_fn(batch):
    image_list = []
    des_list = []
    label_list = []
    
    for a,b,c in batch:
        image_list.append(a)
        des_list.append(b)
        label_list.append(c)

    return torch.stack(image_list, dim=0), des_list, label_list

In [57]:
#One-hot dict
def alp_to_mat(string_list,batch_size):
    one_hot_dict={'a':0,'b':1,'c':2,'d':3,'e':4,'f':5,'g':6,'h':7,'i':8,'j':9,'k':10,'l':11,'m':12,'n':13,'o':14,'p':15,'q':16,'r':17,'s':18,'t':19,'u':20,'v':21,'w':22,'x':23,'y':24,'z':25,'0':26,'1':27,'2':28,'3':29,'4':30,'5':31,'6':32,'7':33,'8':34,'9':35,' ':36}
    alp_to_num_list=[]
    batch_size=len(string_list)
    alp_to_mat_list=torch.zeros((batch_size,len(string_list[0]),37)) #(B,name_len,one_hot_len)
    for j in range(batch_size):
        for i in range(len(string_list[j])):
            char=string_list[j][i]
            if(char in one_hot_dict.keys()):
                pass
            else:
                char=' '
                
            alp_to_num_list.append(one_hot_dict[char])
            mat=torch.zeros(37)
    #         print(len(string_list))
    #         print(mat.shape)
            mat[one_hot_dict[char]]=1
#             print(f"{alp_to_mat_list.shape}-{mat.shape}")
            alp_to_mat_list[j,i,:]=mat
    #         result=torch.Tensor(alp_to_mat_list)
    return alp_to_mat_list
    


# test=[list("heloo     "),list("     fffff"),list("kdkdkkdkdk"),list("abcdefghij")]
# result=alp_to_mat(test,len(test))
# result.shape

## Train

In [58]:
def train_one_epoch(dataloaders, model, criterion, optimizer, device):
    train_loss = defaultdict(float)
    val_loss = defaultdict(float)
    for phase in ["train", "val"]:
        if phase == "train":
            model.train()
        else:
            model.eval()
            
        running_loss = defaultdict(float)
        for index, batch in enumerate(dataloaders[phase]):
            images = batch[0].to(device)
            description = batch[1]
            label = batch[2]
            
            target=alp_to_mat(label,len(label))
            
            with torch.set_grad_enabled(phase == "train"):
                predictions = model(images, description)
                
            target_for_loss=target.view(-1,370).to(device)
            predictions_for_loss=predictions.view(-1,370).to(device)

            loss=criterion(predictions_for_loss, target_for_loss, torch.Tensor(predictions_for_loss.size(0)).cuda().fill_(1.0))
            
            if phase == "train":
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                running_loss["total_loss"] += loss.item()
                
                train_loss["total_loss"] += loss.item()
                
                if (index > 0) and (index % VERBOSE_FREQ) == 0:
                    text = f"<<<iteration:[{index}/{len(dataloaders[phase])}] - "
                    for k, v in running_loss.items():
                        text += f"{k}: {v/VERBOSE_FREQ:.4f}  "
                        running_loss[k] = 0.
                    print(text)
            else:
                val_loss["total_loss"] += loss.item()
        
    for k in train_loss.keys():
        train_loss[k] /= len(dataloaders["train"])
        val_loss[k] /= len(dataloaders["val"])
            
    return train_loss, val_loss

In [59]:
PATH="/workspace/team2/data/filter_50000/"

is_cuda = True

IMAGE_SIZE = 448
BATCH_SIZE = 16
VERBOSE_FREQ = 20
LR=0.001

IMAGE_ENC="RESNET34"
TEXT_ENC="BERT"
DECODER="LSTM"
num_epochs = 100
# DEVICE = torch.device('cuda' if torch.cuda.is_available and is_cuda else 'cpu')

dataloaders = build_dataloader(PATH=PATH, batch_size=BATCH_SIZE)


Image_Enc = Image_Encoder()
Text_Enc=Text_Encoder(device=device)
Decoder=lstm_decoder(input_size=37, hidden_size=768)
model=nickCLIP(image_encoder=Image_Enc, text_encoder=Text_Enc, decoder=Decoder, device=device)
model.to(device)

criterion = torch.nn.CosineEmbeddingLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

trainset:43864 validset:11340


In [18]:
import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="nickclip_baseline",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": LR,
    "batch_size": BATCH_SIZE,
    "image encoder":IMAGE_ENC,
    "test encoder":TEXT_ENC,
    "decoder":DECODER,
    "dataset": PATH,
    "epochs": num_epochs,
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mgomduribo[0m ([33murp[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
best_epoch = 0
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    train_loss, val_loss = train_one_epoch(dataloaders, model, criterion, optimizer, device)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
#     wandb.log({"Train Loss": train_loss['total_loss'],
#                "Val Loss": val_loss['total_loss'],})
    print(f"\nepoch:{epoch+1}/{num_epochs} - Train Loss: {train_loss['total_loss']:.4f}, Val Loss: {val_loss['total_loss']:.4f}\n")
    
    if (epoch+1) % 10 == 0:
        save_model(model.state_dict(), f'model_{epoch+1}.pth', save_dir=f"./trained_model/{IMAGE_ENC}_{TEXT_ENC}_{DECODER}")
wandb.finish()

<<<iteration:[20/2742] - total_loss: 0.8908  
<<<iteration:[40/2742] - total_loss: 0.8179  
<<<iteration:[60/2742] - total_loss: 0.8105  
<<<iteration:[80/2742] - total_loss: 0.8076  
<<<iteration:[100/2742] - total_loss: 0.8065  
<<<iteration:[120/2742] - total_loss: 0.8028  
<<<iteration:[140/2742] - total_loss: 0.8068  
<<<iteration:[160/2742] - total_loss: 0.8052  
<<<iteration:[180/2742] - total_loss: 0.7988  
<<<iteration:[200/2742] - total_loss: 0.7983  
<<<iteration:[220/2742] - total_loss: 0.8050  
<<<iteration:[240/2742] - total_loss: 0.8014  
<<<iteration:[260/2742] - total_loss: 0.8027  
<<<iteration:[280/2742] - total_loss: 0.7973  
<<<iteration:[300/2742] - total_loss: 0.7948  
<<<iteration:[320/2742] - total_loss: 0.8000  
<<<iteration:[340/2742] - total_loss: 0.7926  
<<<iteration:[360/2742] - total_loss: 0.7940  
<<<iteration:[380/2742] - total_loss: 0.7946  
<<<iteration:[400/2742] - total_loss: 0.7971  
<<<iteration:[420/2742] - total_loss: 0.7965  
<<<iteration:[440



<<<iteration:[2100/2742] - total_loss: 0.7917  
<<<iteration:[2120/2742] - total_loss: 0.7864  
<<<iteration:[2140/2742] - total_loss: 0.7829  
<<<iteration:[2160/2742] - total_loss: 0.7892  
<<<iteration:[2180/2742] - total_loss: 0.7880  
<<<iteration:[2200/2742] - total_loss: 0.7857  
<<<iteration:[2220/2742] - total_loss: 0.7851  
<<<iteration:[2240/2742] - total_loss: 0.7889  
<<<iteration:[2260/2742] - total_loss: 0.7887  
<<<iteration:[2280/2742] - total_loss: 0.7858  
<<<iteration:[2300/2742] - total_loss: 0.7842  
<<<iteration:[2320/2742] - total_loss: 0.7866  
<<<iteration:[2340/2742] - total_loss: 0.7845  
<<<iteration:[2360/2742] - total_loss: 0.7879  
<<<iteration:[2380/2742] - total_loss: 0.7884  
<<<iteration:[2400/2742] - total_loss: 0.7844  
<<<iteration:[2420/2742] - total_loss: 0.7831  
<<<iteration:[2440/2742] - total_loss: 0.7874  
<<<iteration:[2460/2742] - total_loss: 0.7886  
<<<iteration:[2480/2742] - total_loss: 0.7878  
<<<iteration:[2500/2742] - total_loss: 0



<<<iteration:[1140/2742] - total_loss: 0.7869  
<<<iteration:[1160/2742] - total_loss: 0.7853  
<<<iteration:[1180/2742] - total_loss: 0.7851  
<<<iteration:[1200/2742] - total_loss: 0.7860  
error: image name=>charliesheen_50.jpg des name=>charliesheen_50.txt label name=>charliesheen_50.txt
<<<iteration:[1220/2742] - total_loss: 0.7868  
<<<iteration:[1240/2742] - total_loss: 0.7902  
<<<iteration:[1260/2742] - total_loss: 0.7871  
<<<iteration:[1280/2742] - total_loss: 0.7850  
<<<iteration:[1300/2742] - total_loss: 0.7831  
<<<iteration:[1320/2742] - total_loss: 0.7889  
<<<iteration:[1340/2742] - total_loss: 0.7846  
<<<iteration:[1360/2742] - total_loss: 0.7862  
<<<iteration:[1380/2742] - total_loss: 0.7886  
<<<iteration:[1400/2742] - total_loss: 0.7829  
<<<iteration:[1420/2742] - total_loss: 0.7834  
<<<iteration:[1440/2742] - total_loss: 0.7841  
<<<iteration:[1460/2742] - total_loss: 0.7824  
<<<iteration:[1480/2742] - total_loss: 0.7903  
<<<iteration:[1500/2742] - total_lo

## Inference

In [82]:
def load_model(ckpt_path, device):
    checkpoint = torch.load(ckpt_path, map_location=device)
    model=nickCLIP(image_encoder=Image_Enc, text_encoder=Text_Enc, decoder=Decoder, device=device)
    model.load_state_dict(checkpoint)
    model = model.to(device)
    model.eval()
    return model

In [83]:
transformer = A.Compose([
            A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )

In [84]:
ckpt_path="/workspace/team2/yb_workspace/nickCLIP/experiments/trained_model/RESNET34_BERT_LSTM/model_90.pth"
model = load_model(ckpt_path, device)

In [85]:
root='/workspace/team2/data/filter_50000/'
test_dataset=Dataset(root=root, phase="test", transformer=transformer)
test_dataloaders = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
len(test_dataset)

8464

In [87]:
def idx_to_c(idx):
    one_hot_dict={'a':0,'b':1,'c':2,'d':3,'e':4,'f':5,'g':6,'h':7,'i':8,'j':9,'k':10,'l':11,'m':12,'n':13,'o':14,'p':15,'q':16,'r':17,'s':18,'t':19,'u':20,'v':21,'w':22,'x':23,'y':24,'z':25,'0':26,'1':27,'2':28,'3':29,'4':30,'5':31,'6':32,'7':33,'8':34,'9':35,' ':36}
    idx_to_c = {v: k for k, v in one_hot_dict.items()}
    return idx_to_c[idx]
# print(idx_to_c)
@interact(index=(0, len(test_dataset)-1))
def show_sample(index):
    img, des, label = test_dataset[index]
#     image=img['image'].permute(1,2,0).numpy()
    image=img.unsqueeze(0).to(device)
    print(f"description: {des}")
    print(f"label: {''.join(label)}")
    
    predictions = model(images, des)
    result_torch=torch.argmax(predictions, dim=2)
    result_list=list(result_torch[0])
    result=list(map(int, result_list))
    result_str=''.join(list(map(idx_to_c, result)))
    print(predictions[0][1])
    print(result)
    print(result_str)
#     print
    
    
    image=image.squeeze(0).permute(1,2,0).cpu().numpy()
    plt.imshow(image)
    plt.axis('off')
    plt.show()

interactive(children=(IntSlider(value=4231, description='index', max=8463), Output()), _dom_classes=('widget-i…