In [40]:
import numpy as np
import os 
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torch import nn
import torchsummary
from torch.utils.data import DataLoader
from collections import defaultdict
from torchvision.utils import make_grid

import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
import logging
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## Datasets

In [54]:
class Dataset():
    def __init__(self, root, phase, transformer=None):
        self.root=root
        self.phase=phase
        self.transformer=transformer
        self.image_list=sorted(os.listdir(root+"image/"+phase))
        self.des_list=sorted(os.listdir(root+"description/"+phase))
        self.label_list=sorted(os.listdir(root+"label/"+phase))
        
    def __getitem__(self, index):
        img, des, label = self.get_data(index)
        return img, des, label
        
    def __len__(self, ):
        return len(self.image_list)

    def get_data(self, index):
        # label
        label_file_name=self.label_list[index]
        lab_f=open(self.root+"label/"+self.phase+"/"+label_file_name, "r")
        label=lab_f.read()

        # description
        des_file_name=self.des_list[index]
        des_f=open(self.root+"description/"+self.phase+"/"+des_file_name, "r")
        des_text=des_f.read()
#         des=des_text.split(" ")
        des=des_text

        # image
        img_file_name=self.image_list[index]
        image=cv2.imread(self.root+"image/"+self.phase+"/"+img_file_name)
        img=cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if(self.transformer!=None):
            transformed_img=self.transformer(image=img)
            img=transformed_img
        
        return img, des, label
        
    

In [55]:
IMAGE_SIZE=448
transformer = A.Compose([
            A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )

In [56]:
root='/workspace/team2/data/nickData/'
train_dataset=Dataset(root=root, phase="train", transformer=transformer)
len(train_dataset)

In [59]:
@interact(index=(0, len(train_dataset)-1))
def show_sample(index):
    img, des, label = train_dataset[index]
    image=img['image'].permute(1,2,0).numpy()
    print(f'desciption: {des}')
    print(f'label: {label}')
    print(image.shape)
    plt.imshow(image)
    plt.axis('off')
    plt.show()

interactive(children=(IntSlider(value=49, description='index', max=99), Output()), _dom_classes=('widget-inter…

## MODELs
 ![Untitled](../img/nickCLIP_arch.png)

### Image Encoder

In [73]:
class Image_Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        resnet = torchvision.models.resnet34(pretrained = True)
        layers = [m for m in resnet.children()]
        
        self.backbone = nn.Sequential(*layers[:-2]) 
        self.head = nn.Sequential(
                nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, padding=0,bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=128, out_channels=32, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=32, out_channels=4, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(4),
                nn.ReLU(inplace=True),
                nn.Flatten(),
                nn.Linear(in_features=784, out_features=768)
            
            )
        
#         self.head = nn.Sequential(
#                 nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, padding=0,bias=False),
#                 nn.BatchNorm2d(256),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1,bias=False),
#                 nn.BatchNorm2d(128),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(in_channels=128, out_channels=32, kernel_size=3, padding=1,bias=False),
#                 nn.BatchNorm2d(32),
#                 nn.ReLU(inplace=True),
#                 nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding=1,bias=False),
#                 nn.BatchNorm2d(1),
#                 nn.ReLU(inplace=True),
#                 nn.Flatten()
#             )
    def forward(self, x):
        out = self.backbone(x)
        out = self.head(out) # final output=> (1, 196)
        return out

In [74]:
Image_Enc = Image_Encoder()
Image_Enc.to(device)
torchsummary.summary(Image_Enc, (3,448,448))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
            Conv2d-5         [-1, 64, 112, 112]          36,864
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
            Conv2d-8         [-1, 64, 112, 112]          36,864
       BatchNorm2d-9         [-1, 64, 112, 112]             128
             ReLU-10         [-1, 64, 112, 112]               0
       BasicBlock-11         [-1, 64, 112, 112]               0
           Conv2d-12         [-1, 64, 112, 112]          36,864
      BatchNorm2d-13         [-1, 64, 112, 112]             128
             ReLU-14         [-1, 64, 1

### Text Encoder

In [75]:
class Text_Encoder(nn.Module):
    def __init__(self, device, pretrained='bert-base-uncased'):
        super().__init__()
        self.pretrained=pretrained
        self.device=device
        self.BERT = BertModel.from_pretrained(self.pretrained)
    
    def preprocess(self, text):
        tokenizer = BertTokenizer.from_pretrained(self.pretrained)
        marked_text = "[CLS] " + text + " [SEP]"
        tokenized_text = tokenizer.tokenize(marked_text)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        segments_ids = [1] * len(tokenized_text)
        
        tokens_tensor = torch.tensor([indexed_tokens]).to(self.device)
        segments_tensors = torch.tensor([segments_ids]).to(self.device)
        
        return tokens_tensor, segments_tensors
        
    def postprocess(self, encoded_layers):
#         token_embeddings = torch.stack(encoded_layers, dim=0)
#         token_embeddings = token_embeddings.permute(1,0,2)
        token_vecs = encoded_layers[11][0]
        sentence_embedding = torch.mean(token_vecs, dim=0)
        
        return sentence_embedding
    
    def forward(self, x):
        tokens_tensor, segments_tensors = self.preprocess(x)
        
        encoded_layers, _ = self.BERT(tokens_tensor, segments_tensors)
        
        sentence_embedding = self.postprocess(encoded_layers)
        
        return sentence_embedding

In [48]:
Text_Enc=Text_Encoder(device=device)

In [49]:
Text_Enc.to(device)

Text_Encoder(
  (BERT): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): FusedLayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): FusedLayerNorm(torch.Size([768]), eps

In [86]:
@interact(index=(0, len(train_dataset)-1))
def show_sample(index):
    img, des, label = train_dataset[index]
#     image=img['image'].permute(1,2,0).numpy()
    image=img['image'].unsqueeze(0).to(device)
    print(image.shape)
    img_embed=Image_Enc(image)
    sen_embed=Text_Enc(des)
    print(f'desciption: {des}')
    print(f'label: {label}')
    print(f"img shape:{image.shape}")
    print(f'sen_emb shape:{sen_embed.shape}')
    print(f'img_embed shape:{img_embed.shape}')
    image=image.squeeze(0).permute(1,2,0).cpu().numpy()
    plt.imshow(image)
    plt.axis('off')
    plt.show()

interactive(children=(IntSlider(value=49, description='index', max=99), Output()), _dom_classes=('widget-inter…

## Decoder

In [None]:
class Decoder(nn.Module):
    def __init__(self, image_encoder, text_encoder):
        super().__init__()
        pass
    
    def forward(self, x):
        pass
        
        return out