In [1]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torchvision.transforms._transforms_video import ToTensorVideo, RandomResizedCropVideo, RandomHorizontalFlipVideo, NormalizeVideo

In [2]:
# Run on the GPU when one is visible; tensors and the model are moved here later.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [3]:
vid_path = '../../dataset/hmdb51_vid/'
split_path = '../../dataset/hmdb51_split/'

# Class names must line up with the integer labels torchvision's HMDB51 dataset
# assigns, which come from the class *sub-directories* of `vid_path` in sorted
# order. Keeping only directories stops stray files (e.g. `.DS_Store`, archives)
# from shifting every index and pairing clips with the wrong embedding.
class_names = sorted(entry.name for entry in os.scandir(vid_path) if entry.is_dir())

# One embedding vector per class name (presumably Sentence-BERT, per the file
# name). NOTE: pickle can execute arbitrary code on load -- only use a trusted,
# locally generated file here.
with open('./metas/class_embed_sbert.pkl', 'rb') as embed_file:
    class_embed = pickle.load(embed_file)
print(class_embed[class_names[0]].shape)

(768,)


In [4]:
# Stack every class embedding into one matrix with one row per class, rows in
# the same order as `class_names` (so row i matches dataset label i).
total_embed = torch.cat(
    [torch.FloatTensor(class_embed[name]).view(1, -1) for name in class_names],
    dim=0,
)

In [5]:
# Side length frames are cropped/resized to before entering the backbone.
img_size = 224
# Channel count of the truncated backbone output; must match Model.backbone.
feat_dim = 1024
# Hidden width of the attention decoder (a quarter of the feature width).
inter_dim = feat_dim // 4
# Length of each class-name embedding vector.
lang_dim = 768
# Frames per clip, and the frame rate clips are sampled at.
clip_len = 20
fps = 5

In [6]:
to_tensor = ToTensorVideo()
random_crop = RandomResizedCropVideo(img_size, scale=(0.95, 1.05))
random_flip = RandomHorizontalFlipVideo()
normalizer = NormalizeVideo(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# Training pipeline: the random crop and flip are data augmentation.
curr_transform = torchvision.transforms.Compose([to_tensor, random_crop, random_flip, normalizer])


def resize_video(clip):
    """Deterministically resize a (C, T, H, W) float clip to img_size x img_size.

    Used instead of the random crop/flip at evaluation time. The leading two
    axes are treated as batch/channel by F.interpolate, so only H and W change.
    """
    return F.interpolate(clip, size=(img_size, img_size), mode='bilinear', align_corners=False)


# Evaluation pipeline: no random augmentation. Reusing `curr_transform` here
# made every test pass see differently cropped/flipped frames, so test loss and
# accuracy drifted between epochs even for identical weights.
test_transform = torchvision.transforms.Compose([to_tensor, resize_video, normalizer])

train_dataset = torchvision.datasets.HMDB51(vid_path, 
                                           split_path, 
                                           frames_per_clip=clip_len, step_between_clips=clip_len//2, 
                                           frame_rate=fps, fold=1, train=True, 
                                           transform=curr_transform)
test_dataset = torchvision.datasets.HMDB51(vid_path, 
                                           split_path, 
                                           frames_per_clip=clip_len, step_between_clips=clip_len//2, 
                                           frame_rate=fps, fold=1, train=False, 
                                           transform=test_transform)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=423.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=423.0), HTML(value='')))




In [7]:
# One clip per step; train()/test() only ever look at the first clip of a batch.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [8]:
class Model(nn.Module):
    """Language-conditioned spatial attention over a truncated ResNet-50.

    For every row of ``total_embed`` the backbone feature map is AdaIN-modulated
    with class-specific scale/shift, decoded into a one-channel attention map,
    and that map is used to pool the (unmodulated) feature map into one vector
    per class. Each pooled vector is projected and compared by cosine
    similarity with a projection of the query embedding ``cls_embed``.

    Args:
        feat_dim: channel count of the backbone output; must match the
            truncated backbone below.
        inter_dim: hidden width of the attention decoder.
        lang_dim: length of each class-name embedding.
        img_size, clip_len: stored as attributes only; not used in forward.
    """

    # Number of leading backbone children (conv1, bn1, relu, maxpool, layer1)
    # whose parameters are frozen.
    n_frozen_children = 5

    def __init__(self, feat_dim, inter_dim, lang_dim, img_size, clip_len):
        super(Model, self).__init__()

        self.feat_dim = feat_dim
        self.inter_dim = inter_dim
        self.lang_dim = lang_dim
        self.img_size = img_size
        self.clip_len = clip_len

        resnet = torchvision.models.resnet50(pretrained=True, progress=True)
        children = list(resnet.children())
        for child in children[:self.n_frozen_children]:
            for param in child.parameters():
                param.requires_grad = False
        # Keep the first 7 children (through layer3); layer4, avgpool, fc are dropped.
        self.backbone = nn.Sequential(*children[:7])

        # Per-class AdaIN (scale, shift) for the backbone feature and for the
        # decoder's hidden feature, plus the projection of the query embedding.
        self.adain_linear_1 = nn.Linear(lang_dim, feat_dim*2)
        self.adain_linear_2 = nn.Linear(lang_dim, inter_dim*2)
        self.cls_query_linear = nn.Linear(lang_dim, feat_dim)

        # Attention decoder: feat_dim -> inter_dim -> 1 channel.
        self.dec_conv1 = nn.Conv2d(feat_dim, inter_dim, 3, padding=1, bias=False)
        self.dec_relu = nn.ReLU()
        self.dec_dropout = nn.Dropout(0.1)
        self.dec_conv2 = nn.Conv2d(inter_dim, 1, 1)

        self.attn_feat_linear = nn.Linear(feat_dim, feat_dim)

        self.cos = nn.CosineSimilarity(dim=1)

    def train(self, mode=True):
        """Set train/eval mode while keeping the frozen backbone prefix in eval mode.

        ``requires_grad = False`` only stops gradient updates. In train mode
        BatchNorm still normalises with per-batch statistics and overwrites its
        running mean/var, so the "frozen" layers would silently drift away from
        their pretrained behaviour (and differ between train and eval). Forcing
        them to eval mode keeps them genuinely fixed.
        """
        super(Model, self).train(mode)
        if mode:
            for child in list(self.backbone.children())[:self.n_frozen_children]:
                child.eval()
        return self

    def forward(self, img, cls_embed, total_embed):
        """Score every class for one image against a single query class.

        Args:
            img: image batch for the backbone; exactly one image is supported
                because the feature map is flattened without a batch axis.
            cls_embed: embedding of the query class; reshaped to (1, lang_dim).
            total_embed: (num_classes, lang_dim) embeddings of all classes.

        Returns:
            cos_sim: (num_classes,) cosine similarity between each class's
                attention-pooled feature and the projected query.
            attn_map: (num_classes, 1, h, w) raw (un-normalised) attention maps
                at the backbone feature resolution.

        Raises:
            ValueError: if ``img`` holds more than one image.
        """
        if img.shape[0] != 1:
            raise ValueError('Model.forward expects a batch of exactly one image, got {}'.format(img.shape[0]))

        total_adain_1 = self.adain_linear_1(total_embed)
        total_adain_2 = self.adain_linear_2(total_embed)

        cls_embed = cls_embed.view(1, -1)
        cls_query = self.cls_query_linear(cls_embed)

        # get feature: (1, feat_dim, h, w)
        feature = self.backbone(img)

        # per-channel spatial mean and std of the feature
        feature_mean = torch.mean(feature, dim=[2, 3], keepdim=True)
        feature_std = torch.std(feature, dim=[2, 3], keepdim=True)

        # AdaIN: the (num_classes, feat_dim, 1, 1) scale/shift broadcast against
        # the single feature map, giving one modulated copy per class.
        normed_feature = torch.div(feature - feature_mean, feature_std+1e-8)
        adain_feature = torch.mul(normed_feature, total_adain_1[:, :self.feat_dim, None, None])
        adain_feature = torch.add(adain_feature, total_adain_1[:, self.feat_dim:, None, None])

        # first decoder layer on the modulated features
        dec_result = self.dec_conv1(adain_feature)
        dec_result = self.dec_dropout(self.dec_relu(dec_result))

        # AdaIN again on the decoder's hidden features
        dec_mean = torch.mean(dec_result, dim=[2, 3], keepdim=True)
        dec_std = torch.std(dec_result, dim=[2, 3], keepdim=True)

        dec_result = torch.div(dec_result - dec_mean, dec_std+1e-8)
        dec_result = torch.mul(dec_result, total_adain_2[:, :self.inter_dim, None, None])
        dec_result = torch.add(dec_result, total_adain_2[:, self.inter_dim:, None, None])

        # final decoder layer: one attention map per class
        attn_map = self.dec_conv2(dec_result)

        # pool the unmodulated feature with each class's attention map
        feature_flat = feature.squeeze(0).view(self.feat_dim, -1)
        attn_map_flat = attn_map.squeeze(1).view(attn_map.shape[0], -1)

        attn_feat = torch.mm(attn_map_flat, feature_flat.permute(1, 0))

        # cosine similarity between attn_feat (N_c x D) and the query (1 x D)
        cos_sim = self.cos(self.attn_feat_linear(attn_feat), cls_query)

        return cos_sim, attn_map
        

In [9]:
model = Model(feat_dim, inter_dim, lang_dim, img_size, clip_len).to(device)
# All parameters are handed to Adam; frozen ones never get a .grad, so Adam skips them.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-8)

In [16]:
def train(model, optimizer, sample, total_embed):
    """Run one optimisation step on a single randomly chosen frame of a clip.

    Args:
        model: the Model to update (switched to train mode here).
        optimizer: optimizer over ``model``'s parameters.
        sample: one batch from the HMDB51 dataloader; ``sample[0]`` holds the
            clips and ``sample[2]`` the class-index tensor.
        total_embed: matrix of all class embeddings, one row per class.

    Returns:
        The training loss as a Python float.

    Reads the notebook globals ``clip_len``, ``class_embed``, ``class_names``
    and ``device``.
    """
    model.train()
    optimizer.zero_grad()

    # Take one random frame of the first clip: (C, T, H, W) -> (1, C, H, W).
    frame_pos = np.random.randint(low=0, high=clip_len)
    clip = sample[0][0]
    frame = clip[:, frame_pos:frame_pos + 1].permute(1, 0, 2, 3)

    target = sample[2]
    query_embed = torch.FloatTensor(class_embed[class_names[target.item()]])

    logits, _ = model(frame.to(device), query_embed.to(device), total_embed.to(device))

    # NOTE(review): the logits are cosine similarities in [-1, 1], so softmax
    # cross-entropy over all classes cannot get close to zero. Cosine
    # classifiers usually multiply by a (learnable) temperature/scale first.
    loss_fn = nn.CrossEntropyLoss()
    loss = loss_fn(logits.view(1, -1), target.to(device))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    return loss.item()

In [17]:
def test(model, sample, total_embed, label_free=False):
    """Evaluate one clip on its middle frame.

    By default the ground-truth class embedding is fed to the model as the
    query, exactly as in ``train``. NOTE(review): that conditions the
    prediction on the label. Every class is scored against the *true* class's
    query, so accuracy measured this way is not a label-free accuracy. Pass
    ``label_free=True`` to score each class ``c`` against its own embedding
    instead. That mode runs one forward pass per class, so it is much slower.

    Args:
        model: the Model to evaluate (switched to eval mode here).
        sample: one batch from the HMDB51 dataloader; ``sample[0]`` holds the
            clips and ``sample[2]`` the class-index tensor.
        total_embed: matrix of all class embeddings, one row per class, in
            ``class_names`` order.
        label_free: if True, do not use the ground-truth label to build the query.

    Returns:
        (loss, accu, attn_map): scalar loss, whether the argmax class equals
        the label, and the model's per-class attention maps.

    Reads the notebook globals ``clip_len``, ``class_embed``, ``class_names``
    and ``device``.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        # middle frame of the first clip: (C, T, H, W) -> (1, C, H, W)
        frame_index = clip_len//2
        img = sample[0][0][:, frame_index:(frame_index+1), :, :].permute(1, 0, 2, 3).to(device)
        cls_label = sample[2].to(device)
        total_embed = total_embed.to(device)

        if label_free:
            # Logit for class c = similarity of class c's pooled feature with
            # class c's own query. attn_map does not depend on the query, so
            # the one from the last pass is valid for all classes.
            logits = []
            for c_idx, name in enumerate(class_names):
                query = torch.FloatTensor(class_embed[name]).to(device)
                class_sims, attn_map = model(img, query, total_embed)
                logits.append(class_sims[c_idx])
            cos_sim = torch.stack(logits)
        else:
            cls_embed = torch.FloatTensor(class_embed[class_names[cls_label.item()]])
            cos_sim, attn_map = model(img, cls_embed.to(device), total_embed)

        # get loss and accuracy
        loss = criterion(cos_sim.view(1, -1), cls_label)
        esti_label = torch.argmax(cos_sim)
        accu = esti_label.item() == cls_label.item()

    return loss.item(), accu, attn_map

In [None]:
max_epoch = 200
dir_name = 'first'
writer = SummaryWriter(logdir=os.path.join('./runs', dir_name))
n_iter = 0
max_accu = 0.0
save_stride = 50
tmp_path = './checkpoint.tar'
save_dir = './single_model'
# torch.save does not create missing directories.
os.makedirs(save_dir, exist_ok=True)

for epoch in range(max_epoch):
    # ---- train ----
    # model/optimizer already hold the latest state in memory, so the
    # checkpoint written at the end of the previous epoch is not reloaded.
    with tqdm(total=len(train_dataloader)) as pbar:
        for sample in train_dataloader:
            curr_loss = train(model, optimizer, sample, total_embed)
            writer.add_scalar('train/loss', curr_loss, n_iter)
            n_iter += 1
            pbar.update(1)

    # Resumable checkpoint (model + optimizer), overwritten every epoch.
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': curr_loss,
            }, tmp_path)

    # ---- evaluate ----
    test_loss = 0.0
    test_accu = 0.0
    # Pick the test sample to visualise with a local RNG seeded by the epoch.
    # Seeding the *global* NumPy RNG here also fixed train()'s random frame
    # choice for the whole next epoch.
    plot_idx = np.random.RandomState(epoch).randint(0, len(test_dataloader))

    with tqdm(total=len(test_dataloader)) as pbar:
        for idx, sample in enumerate(test_dataloader):
            loss, accu, attn_map = test(model, sample, total_embed)
            test_loss += loss / len(test_dataloader)
            test_accu += float(accu) / len(test_dataloader)
            pbar.update(1)

            if idx == plot_idx:
                # Same frame test() evaluated, with NormalizeVideo undone so
                # imshow receives values in [0, 1].
                frame = sample[0][0][:, clip_len // 2, :, :]
                mean = torch.tensor(normalizer.mean).view(-1, 1, 1)
                std = torch.tensor(normalizer.std).view(-1, 1, 1)
                frame = (frame * std + mean).clamp(0.0, 1.0)
                label_idx = sample[2].item()

                fig, ax = plt.subplots(1, 2, figsize=(8, 4))
                ax[0].imshow(frame.permute(1, 2, 0))
                # attention map of the ground-truth class
                ax[1].imshow(attn_map[label_idx][0].cpu())
                fig.suptitle(accu)
                fig.savefig('./tmp_img.png')
                # Without closing, one figure per epoch accumulates in memory.
                plt.close(fig)
                tmp_img = torch.FloatTensor(plt.imread('./tmp_img.png')[:, :, :3]).permute(2, 0, 1)
                writer.add_image('test/attn_map', tmp_img, epoch)

    writer.add_scalar('test/loss', test_loss, epoch)
    writer.add_scalar('test/accuracy', test_accu, epoch)

    max_accu = max(test_accu, max_accu)

    if test_accu == max_accu:
        torch.save(model.state_dict(), os.path.join(save_dir, '{}_best.pth'.format(dir_name)))
    torch.save(model.state_dict(), os.path.join(save_dir, '{}_recent.pth'.format(dir_name)))
    if (epoch+1) % save_stride == 0:
        torch.save(model.state_dict(), os.path.join(save_dir, '{}_{}.pth'.format(dir_name, epoch+1)))

# Flush any pending TensorBoard events.
writer.close()
        

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1436.0), HTML(value='')))

In [None]:
# Sanity check: shape of the last logged figure, keeping its first three (RGB) channels.
plt.imread('./tmp_img.png')[..., :3].shape