In [3]:
import torch
import torch.nn as nn
from einops import rearrange, repeat
from model.attention import Attention

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import numpy as np
import os
import json
from torch.utils.data.dataset import Dataset

In [5]:
from model.patch_embed import PatchEmbedding
from model.transformer_layer import TransformerLayer

In [6]:
import yaml
import argparse
import os
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader

In [7]:
import random
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [8]:
import cv2

ImportError: libGL.so.1: cannot open shared object file: No such file or directory

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
device

device(type='cuda')

In [11]:

def get_random_crop(image, crop_h, crop_w):
    h, w = image.shape[:2]
    max_x = w - crop_w
    max_y = h - crop_h
    
    x = np.random.randint(0, max_x)
    y = np.random.randint(0, max_y)
    crop = image[y: y + crop_h, x: x + crop_w, :]
    return crop


def get_center_crop(image):
    h, w = image.shape[:2]
    if h > w:
        return image[(h - w) // 2:-(h - w) // 2, :, :]
    else:
        return image[:, (w - h) // 2:-(w - h) // 2, :]


class MnistDataset(Dataset):
    r"""
    Minimal image dataset where we take mnist images
    add a texture background
    change the color of the digit.
    Model trained on this dataset is then required to predict the below 3 values
    1. Class of texture
    2. Class of number
    3. R, G, B values (0-1) of the digit color
    """
    def __init__(self, split, config, im_h=224, im_w=224):
        self.split = split
        self.db_root = config['root_dir']
        self.im_h = im_h
        self.im_w = im_w
        
        imdb = json.load(open(os.path.join(self.db_root,  'imdb.json')))
        self.im_info = imdb['{}_data'.format(split)]
        self.texture_to_idx = imdb['texture_classes_index']
        self.idx_to_texture = {v:k for k,v in self.texture_to_idx.items()}
        
    def __len__(self):
        return len(self.im_info)
    
    def __getitem__(self, index):
        entry = self.im_info[index]
        digit_cls = int(entry['digit_name'])
        digit_im = cv2.imread(os.path.join(self.db_root, entry['digit_image']))
        digit_im = cv2.cvtColor(digit_im, cv2.COLOR_BGR2RGB)
        digit_im = cv2.resize(digit_im, (self.im_h, self.im_w))
        
        # Discretize mnist images to be either 0 or 1
        digit_im[digit_im > 50] = 255
        digit_im[digit_im <= 50] = 0
        mask_val = (digit_im > 0).astype(np.float32)
        digit_im = np.concatenate((digit_im[:, :, 0][..., None] * float(entry['color_r']),
                                   digit_im[:, :, 1][..., None] * float(entry['color_g']),
                                   digit_im[:, :, 2][..., None] * float(entry['color_b'])), axis=-1)
        im = cv2.imread(os.path.join(self.db_root, entry['texture_image']))
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        if self.split == 'train':
            im = get_random_crop(im, self.im_h, self.im_w)
        else:
            im = get_center_crop(im)
            im = cv2.resize(im, (self.im_h, self.im_w))
        out_im = mask_val * digit_im + (1 - mask_val) * im
        im_tensor = torch.from_numpy(out_im).permute((2, 0, 1))
        im_tensor = 2 * (im_tensor / 255) - 1
        return {
            "image" : im_tensor,
            "texture_cls" : self.texture_to_idx[entry['texture_name']],
            "number_cls" : digit_cls,
            "color":torch.as_tensor([float(entry['color_r']),
                                     float(entry['color_g']),
                                      float(entry['color_b'])])
        }


In [12]:
class PatchEmbedding(nn.Module):
    r"""
    Layer to take in the input image and do the following:
        1.  Transform grid of image into a sequence of patches.
            Number of patches are decided based on image height,width and
            patch height, width.
        2. Add cls token to the above created sequence of patches in the
            first position
        3. Add positional embedding to the above sequence(after adding cls)
        4. Dropout if needed
    """
    def __init__(self, config):
        super().__init__()
        # Example configuration
        #   Image c,h,w : 3, 224, 224
        #   Patch h,w : 16, 16
        image_height = config['image_height']
        image_width = config['image_width']
        im_channels = config['im_channels']
        emb_dim = config['emb_dim']
        patch_embd_drop = config['patch_emb_drop']
        
        self.patch_height = config['patch_height']
        self.patch_width = config['patch_width']
        
        # Compute number of patches for positional parameters initialization
        #   num_patches = num_patches_h * num_patches_w
        #   num_patches = 224/16 * 224/16
        #   num_patches = 196
        num_patches = (image_height // self.patch_height) * (image_width // self.patch_width)
        
        # This is the input dimension of the patch_embed layer
        # After patchifying the 224, 224, 3 image will be
        # num_patches x patch_h x patch_w x 3
        # Which will be 196 x 16 x 16 x 3
        # Hence patch dimension = 16 * 16 * 3
        patch_dim = im_channels * self.patch_height * self.patch_width
        
        self.patch_embed = nn.Sequential(
            # This pre and post layer norm speeds up convergence
            # Comment them if you want pure vit implementation
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, emb_dim),
            nn.LayerNorm(emb_dim)
        )
        
        # Positional information needs to be added to cls as well so 1+num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim))
        self.cls_token = nn.Parameter(torch.randn(emb_dim))
        self.patch_emb_dropout = nn.Dropout(patch_embd_drop)
        
    def forward(self, x):
        batch_size = x.shape[0]
        
        # This is doing the B, 3, 224, 224 -> (B, num_patches, patch_dim) transformation
        # B, 3, 224, 224 -> B, 3, 14*16, 14*16
        # B, 3, 14*16, 14*16 -> B, 3, 14, 16, 14, 16
        # B, 3, 14, 16, 14, 16 -> B, 14, 14, 16, 16, 3
        #  B, 14*14, 16*16*3 - > B, num_patches, patch_dim
        out = rearrange(x, 'b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)',
                      ph=self.patch_height,
                      pw=self.patch_width)
        out = self.patch_embed(out)
        
        # Add cls
        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b=batch_size)
        out = torch.cat((cls_tokens, out), dim=1)
        
        # Add position embedding and do dropout
        out += self.pos_embed
        out = self.patch_emb_dropout(out)
        
        return out

    

In [13]:
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_heads = config['n_heads']
        self.head_dim = config['head_dim']
        self.emb_dim = config['emb_dim']
        self.drop_prob = config['dropout'] if 'dropout' in config else 0.0
        self.att_dim = self.n_heads * self.head_dim
        
        self.qkv_proj = nn.Linear(self.emb_dim, 3 * self.att_dim, bias=False)
        self.output_proj = nn.Sequential(
            nn.Linear(self.att_dim, self.emb_dim),
            nn.Dropout(self.drop_prob))

        self.attn_dropout = nn.Dropout(self.drop_prob)

    def forward(self, x):
        
        #  Converting to Attention Dimension
        ######################################################
        # Batch Size x Number of Patches x Dimension
        B, N = x.shape[:2]
        # Projecting to 3*att_dim and then splitting to get q, k v(each of att_dim)
        # qkv -> Batch Size x Number of Patches x (3* Attention Dimension)
        # q(as well as k and v) -> Batch Size x Number of Patches x Attention Dimension
        q, k ,v = self.qkv_proj(x).split(self.att_dim, dim=-1)
        # Batch Size x Number of Patches x Attention Dimension
        # -> Batch Size x Number of Patches x (Heads * Head Dimension)
        # -> Batch Size x Number of Patches x (Heads * Head Dimension)
        # -> Batch Size x Heads x Number of Patches x Head Dimension
        # -> B x H x N x Head Dimension
        q = rearrange(q, 'b n (n_h h_dim) -> b n_h n h_dim',
                      n_h=self.n_heads, h_dim=self.head_dim)
        k = rearrange(k, 'b n (n_h h_dim) -> b n_h n h_dim',
                      n_h=self.n_heads, h_dim=self.head_dim)
        v = rearrange(v, 'b n (n_h h_dim) -> b n_h n h_dim',
                      n_h=self.n_heads, h_dim=self.head_dim)
        #########################################################
        
        # Compute Attention Weights
        #########################################################
        # B x H x N x Head Dimension @ B x H x Head Dimension x N
        # -> B x H x N x N
        att = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim**(-0.5))
        att = torch.nn.functional.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        #########################################################
        
        # Weighted Value Computation
        #########################################################
        #  B x H x N x N @ B x H x N x Head Dimension
        # -> B x H x N x Head Dimension
        out = torch.matmul(att, v)
        #########################################################
        
        # Converting to Transformer Dimension
        #########################################################
        # B x N x (Heads * Head Dimension) -> B x N x (Attention Dimension)
        out = rearrange(out, 'b n_h n h_dim -> b n (n_h h_dim)')
        #  B x N x Dimension
        out = self.output_proj(out)
        ##########################################################
        
        return out
    


In [14]:
class TransformerLayer(nn.Module):
    r"""
    Transformer block which is just doing the following
        1. LayerNorm followed by Attention
        2. LayerNorm followed by Feed forward Block
        Both these also have residuals added to them
    """
    def __init__(self, config):
        super().__init__()
        emb_dim = config['emb_dim']
        ff_hidden_dim = config['ff_dim'] if 'ff_dim' in config else 4*emb_dim
        ff_drop_prob = config['ff_drop'] if 'ff_drop' in config else 0.0
        self.att_norm = nn.LayerNorm(emb_dim)
        self.attn_block = Attention(config)
        self.ff_norm = nn.LayerNorm(emb_dim)
        
        self.ff_block = nn.Sequential(
            nn.Linear(emb_dim, ff_hidden_dim),
            nn.GELU(),
            nn.Dropout(ff_drop_prob),
            nn.Linear(ff_hidden_dim, emb_dim),
            nn.Dropout(ff_drop_prob)
        )
        
    def forward(self, x):
        out = x
        out = out + self.attn_block(self.att_norm(out))
        out = out + self.ff_block(self.ff_norm(out))
        return out


In [15]:
class VIT(nn.Module):
    def __init__(self, config):
        super().__init__()
        n_layers = config['n_layers']
        emb_dim = config['emb_dim']
        num_classes = config['num_classes']
        self.patch_embed_layer = PatchEmbedding(config)
        self.layers = nn.ModuleList([
            TransformerLayer(config) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(emb_dim)
        self.fc_number = nn.Linear(emb_dim, num_classes)
        
    def forward(self, x):
        # Patchify and add CLS token
        out = self.patch_embed_layer(x)
        
        # Go through the transformer layers
        for layer in self.layers:
            out = layer(out)
        out = self.norm(out)
        
        # Compute logits
        return self.fc_number(out[:, 0])


In [21]:




def train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer):
    r"""
    Method to run the training for one epoch.
    :param epoch_idx: iteration number of current epoch
    :param model: Transformer model
    :param mnist_loader: Data loder for mnist
    :param optimizer: optimizer to be used taken from config
    :return:
    """
    losses = []
    criterion = torch.nn.CrossEntropyLoss()
    for data in tqdm(mnist_loader):
        im = data['image'].float().to(device)
        number_cls = data['number_cls'].long().to(device)
        optimizer.zero_grad()
        model_output = model(im)
        loss = criterion(model_output, number_cls)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    print('Finished epoch: {} | Number Loss : {:.4f}'.
          format(epoch_idx + 1,
                 np.mean(losses)))
    return np.mean(losses)


# def train(args):
#     #  Read the config file
#     ######################################
#     with open(args.config_path, 'r') as file:
#         try:
#             config = yaml.safe_load(file)
#         except yaml.YAMLError as exc:
#             print(exc)
#     print(config)
#     #######################################
    
#     # Set the desired seed value
#     ######################################
#     seed = config['train_params']['seed']
#     torch.manual_seed(seed)
#     np.random.seed(seed)
#     random.seed(seed)
#     if device == 'cuda':
#         torch.cuda.manual_seed_all(args.seed)
#     #######################################
    
#     # Create the model and dataset
#     model = VIT(config['model_params']).to(device)
#     mnist = MnistDataset('train', config['dataset_params'],
#                          im_h=config['model_params']['image_height'],
#                          im_w=config['model_params']['image_width'])
#     mnist_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=True, num_workers=4)
#     num_epochs = config['train_params']['epochs']
#     optimizer = Adam(model.parameters(), lr=config['train_params']['lr'])
#     scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)
    
#     # Create output directories
#     if not os.path.exists(config['train_params']['task_name']):
#         os.mkdir(config['train_params']['task_name'])
    
#     # Load checkpoint if found
#     if os.path.exists(os.path.join(config['train_params']['task_name'],
#                                    config['train_params']['ckpt_name'])):
#         print('Loading checkpoint')
#         model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
#                                                       config['train_params']['ckpt_name']), map_location=device))
#     best_loss = np.inf
    
#     for epoch_idx in range(num_epochs):
#         mean_loss = train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer)
#         scheduler.step(mean_loss)
#         # Simply update checkpoint if found better version
#         if mean_loss < best_loss:
#             print('Improved Loss to {:.4f} .... Saving Model'.format(mean_loss))
#             torch.save(model.state_dict(), os.path.join(config['train_params']['task_name'],
#                                                         config['train_params']['ckpt_name']))
#             best_loss = mean_loss
#         else:
#             print('No Loss Improvement')

def train():
   

    # Set the desired seed value
    ######################################
    config = {
        'dataset_params': {
            'root_dir': 'data'
        },

        'model_params': {
            'n_heads': 8,
            'head_dim': 64,
            'emb_dim': 128,
            'attn_drop': 0.1,
            'ff_dim': 256,
            'ff_drop': 0.1,
            'n_layers': 6,
            'bg_classes': 44,
            'num_classes': 10,
            'image_height': 224,
            'image_width': 224,
            'patch_height': 16,
            'patch_width': 16,
            'patch_emb_drop': 0.1,
            'im_channels': 3
        },

        'train_params': {
            'task_name': 'default',
            'batch_size': 64,
            'epochs': 100,
            'lr': 0.001,
            'seed': 1111,
            'ckpt_name': 'vit_ckpt.pth'
        }
    }


    seed = config['train_params']['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # if device == 'cuda':
    #     torch.cuda.manual_seed_all(args.seed)
    #######################################
    
    # Create the model and dataset
    model = VIT(config['model_params']).to(device)
    mnist = MnistDataset('train', config['dataset_params'],
                         im_h=config['model_params']['image_height'],
                         im_w=config['model_params']['image_width'])
    mnist_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=True, num_workers=4)
    num_epochs = config['train_params']['epochs']
    optimizer = Adam(model.parameters(), lr=config['train_params']['lr'])
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)
    
    # Create output directories
    if not os.path.exists(config['train_params']['task_name']):
        os.mkdir(config['train_params']['task_name'])
    
    # Load checkpoint if found
    if os.path.exists(os.path.join(config['train_params']['task_name'],
                                   config['train_params']['ckpt_name'])):
        print('Loading checkpoint')
        model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
                                                      config['train_params']['ckpt_name']), map_location=device))
    best_loss = np.inf
    
    for epoch_idx in range(num_epochs):
        mean_loss = train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer)
        scheduler.step(mean_loss)
        # Simply update checkpoint if found better version
        if mean_loss < best_loss:
            print('Improved Loss to {:.4f} .... Saving Model'.format(mean_loss))
            torch.save(model.state_dict(), os.path.join(config['train_params']['task_name'],
                                                        config['train_params']['ckpt_name']))
            best_loss = mean_loss
        else:
            print('No Loss Improvement')





In [22]:
train()

  0%|          | 0/938 [00:00<?, ?it/s]


error: Caught error in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_113/1451137293.py", line 48, in __getitem__
    digit_im = cv2.cvtColor(digit_im, cv2.COLOR_BGR2RGB)
cv2.error: OpenCV(4.5.3) /tmp/pip-req-build-f51eratu/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'



In [18]:

# if __name__ == '__main__':
#     parser = argparse.ArgumentParser(description='Arguments for vit training')
#     parser.add_argument('--config', dest='config_path',
#                         default='config/default.yaml', type=str)
#     args = parser.parse_args()
#     train(args)


usage: ipykernel_launcher.py [-h] [--config CONFIG_PATH]
ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-e2a93ee7-91e8-4f0e-adfe-94472b39ded7.json


SystemExit: 2

In [None]:

def get_accuracy(model, mnist_loader):
    r"""
    Method to get accuracy for number classification for trained model
    :param model:
    :param mnist_loader:
    :return:
    """
    num_total = 0.
    num_correct = 0.
    
    for data in tqdm(mnist_loader):
        im = data['image'].float().to(device)
        number_cls = data['number_cls'].long().to(device)
        model_output = model(im)
        pred_num_cls_idx = torch.argmax(model_output, dim=-1)
        num_total += pred_num_cls_idx.size(0)
        num_correct += torch.sum(pred_num_cls_idx == number_cls).item()
    num_accuracy = num_correct / num_total
    print('Number Accuracy : {:2f}'.format(num_accuracy))

   
def visualize_pos_embed(model):
    r"""
    Method to save the positional embeddings cosine similarity map
    Assumes number of patches to be 196
    :param model:
    :return:
    """
    # pos_embed = 1 x Num_patches+1 x D
    # Get indexes after CLS
    pos_emb = model.patch_embed_layer.pos_embed.detach().cpu()[0][1:]

    plt.tight_layout(pad=0.1, rect=(0.1, 0.1, 0.9, 0.9))
    fig, axs = plt.subplots(7, 7)
    count = 0
    for i in tqdm(range(196)):
        row = i // 14
        col = i % 14
        if row % 2 == 0 and col % 2 == 0:
            out = torch.cosine_similarity(pos_emb[i], pos_emb, dim=-1)
            fig.add_subplot(7, 7, count+1)
            plt.xticks([])
            plt.yticks([])
            count += 1
            plt.subplots_adjust(0.1, 0.1, 0.9, 0.9)
            plt.imshow(out.reshape(14, 14), vmin=-1, vmax=1)
    for idx, ax in enumerate(axs.flat):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel('')
        ax.set_ylabel('')
    if not os.path.exists('output'):
        os.mkdir('output')
    plt.savefig('output/position_plot.png', bbox_inches='tight')
    
    
def visualize_attn_weights(mnist, model):
    r"""
    Uses trivial implementation of rollout.
    :param mnist:
    :param model:
    :return:
    """
    num_images = 10
    idxs = torch.randint(0, len(mnist) - 1, (num_images,))
    ims = torch.cat([mnist[idx]['image'][None, :] for idx in idxs]).float()
    ims = ims.to(device)
    attentions = []
    
    def get_attention(model, input, output):
        attentions.append(output.detach().cpu())
        
    # Add forward hook
    for name, module in model.named_modules():
        if 'attn_dropout' in name:
            module.register_forward_hook(get_attention)
    
    model(ims)
    
    # Handle residuals
    attentions = [(torch.eye(att.size(-1)) + att)/(torch.eye(att.size(-1)) + att).sum(dim=-1).unsqueeze(-1) for att in attentions]
    
    result = torch.max(attentions[0], dim=1)[0]
    # Max or mean both are fine
    for i in range(1, 6):
        att = torch.max(attentions[i], dim=1)[0]
        result = torch.matmul(att, result)

    masks = result
    masks = masks[:, 0, 1:]
    for i in range(num_images):
        im_input = torch.permute(ims[i].detach().cpu(), (1, 2, 0)).numpy()
        im_input = im_input[:, :, [2, 1, 0]]
        im_input = (im_input+1)/2 * 255
        mask = masks[i].reshape((14, 14)).numpy()
        
        mask = mask/np.max(mask)
        
        mask = cv2.resize(mask, (224, 224), interpolation=cv2.INTER_LINEAR)[..., None]
        if not os.path.exists('output'):
            os.mkdir('output')
        cv2.imwrite('output/input_{}.png'.format(i), im_input)
        cv2.imwrite('output/overlay_{}.png'.format(i), im_input*mask)


def inference(args):
    # Read the config file
    ######################################
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
    print(config)
    #######################################
    
    # Create the model and dataset
    model = VIT(config['model_params']).to(device)
    model.eval()
    mnist = MnistDataset('test', config['dataset_params'],
                         im_h=config['model_params']['image_height'],
                         im_w=config['model_params']['image_width'])
    mnist_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=True, num_workers=4)
    
    # Load checkpoint if found
    if os.path.exists(os.path.join(config['train_params']['task_name'],
                                   config['train_params']['ckpt_name'])):
        print('Loading checkpoint')
        model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
                                                      config['train_params']['ckpt_name']), map_location=device))
    else:
        print('No checkpoint found at {}'.format(os.path.join(config['train_params']['task_name'],
                                   config['train_params']['ckpt_name'])))
    with torch.no_grad():
        # Run inference and measure accuracy on number
        get_accuracy(model, mnist_loader)
        # Visualize positional embedding
        visualize_pos_embed(model)
        # Visualize attention weights
        visualize_attn_weights(mnist, model)
    


parser = argparse.ArgumentParser(description='Arguments for vit training')
parser.add_argument('--config', dest='config_path',
                    default='config/default.yaml', type=str)
args = parser.parse_args()
inference(args)