In [1]:
from pathlib import Path
from datetime import datetime

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch import cuda
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict

from transformers import DistilBertTokenizer, DistilBertModel

from dataloader import BaseDataset, EvaluateDataset
from transformers import ViTModel, ViTConfig

In [2]:
# Evaluation dataset (its log output reports loading the *test* indices/dataset).
# NOTE(review): meaning of the "AA" argument is defined in dataloader.EvaluateDataset — confirm/document.
evaluate_dataset = EvaluateDataset("AA")

02/21/2022 02:48:19 - INFO - root -   Load test indices
02/21/2022 02:48:19 - INFO - root -   Load test dataset


In [3]:
# Evaluation loader: one query per batch. Recall is a sum over queries, so order
# is irrelevant — keep it deterministic (no shuffling) for reproducible runs.
evaluate_dataloader = DataLoader(
    dataset=evaluate_dataset,
    batch_size=1,
    num_workers=1,
    shuffle=False,
    collate_fn=evaluate_dataset.collate_data,
)

In [10]:
# NOTE(review): `device` is not used in this notebook as written — the model is moved to 'cpu'.
device = 'cuda:1' if cuda.is_available() else 'cpu'

MAX_LEN = 150            # tokenizer truncation / padding length
BATCH_SIZE = 64          # NOTE(review): not referenced below (loaders use batch_size=1)
EPOCHS = 1
LEARNING_RATE = 1e-05
DISTIL_BERT_CHECKPOINT = 'distilbert-base-uncased'
RUN_NAME = 'ROS'
TEST_PATH = '../data/processed/quick_test.csv'   # NOTE(review): not referenced in the visible cells
TRAIN_PATH = '../data/ros/train.csv'             # NOTE(review): not referenced in the visible cells
MODEL_SAVE = '../models/'

tokenizer = DistilBertTokenizer.from_pretrained(DISTIL_BERT_CHECKPOINT)

# ViT (vit-base-patch16-224 style) configuration for single-channel 145x145 inputs:
# (128, 145) expert feature maps are zero-padded to 145x145 before reaching the ViT.
IMAGE_SIZE = 145
NUM_CHANNELS = 1
# Only the configuration is needed here; the throwaway ViTModel that used to be built
# (and immediately replaced by CrossModalBERT) wasted a full random initialisation.
configuration = ViTConfig(image_size=IMAGE_SIZE, num_channels=NUM_CHANNELS)

In [11]:
class CrossModalBERT(nn.Module):
    """Binary relevance classifier over (caption, audio-feature) pairs.

    The caption is encoded with DistilBERT (first-token state) and the audio feature
    map with a ViT (pooler output). The two 768-d-style features are concatenated
    (audio first, then text) and mapped to ``num_labels`` logits by a linear layer.
    """

    def __init__(self, bert_checkpoint="distilbert-base-uncased", image_size=145,
                 num_channels=1, num_labels=2):
        """
        Args:
            bert_checkpoint: name/path passed to ``DistilBertModel.from_pretrained``.
            image_size: side length of the square ViT input.
            num_channels: channels of the ViT input.
            num_labels: number of output logits.
        """
        super().__init__()
        self.distil_bert = DistilBertModel.from_pretrained(bert_checkpoint)
        self.configuration = ViTConfig(image_size=image_size, num_channels=num_channels)
        # Build the ViT from this instance's own config (previously a notebook global was read).
        self.vit = ViTModel(self.configuration)

        # Defaults give 768 + 768 = 1536, matching the previous hard-coded width.
        fused_dim = self.distil_bert.config.dim + self.configuration.hidden_size
        self.linear1 = nn.Linear(fused_dim, num_labels)
        # Kept for callers wanting squashed scores; forward() itself returns raw logits.
        self.sigmoid = nn.Sigmoid()

    def forward(self, ids, mask, audio):
        """Score a batch of (caption, audio) pairs.

        Args:
            ids: token ids, presumably (batch, seq_len) as produced by the tokenizer.
            mask: attention mask matching ``ids``.
            audio: ViT pixel values, (batch, num_channels, image_size, image_size).

        Returns:
            Unnormalised logits of shape (batch, num_labels). ``nn.CrossEntropyLoss``
            applies log-softmax itself, so no sigmoid is applied here.
        """
        bert_out = self.distil_bert(input_ids=ids, attention_mask=mask)
        # Pool the first-token ([CLS]) state. Index -1 was used before, but with
        # padding='max_length' that position is almost always a [PAD] token.
        text_feat = bert_out.last_hidden_state[:, 0, :]
        audio_feat = self.vit(pixel_values=audio).pooler_output
        cat_feat = torch.cat((audio_feat, text_feat), dim=1)
        return self.linear1(cat_feat)

# Instantiate the cross-modal classifier; nn.Module.to returns the module itself.
model = CrossModalBERT().to('cpu')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [26]:
# Training dataset (its log output reports loading train indices and 10 train dataset parts).
# NOTE(review): "AA" presumably selects the same variant as EvaluateDataset("AA") — confirm.
dl = BaseDataset("AA")

02/21/2022 03:08:24 - INFO - root -   Load train indices
02/21/2022 03:08:54 - INFO - root -   Load train dataset 0
02/21/2022 03:09:00 - INFO - root -   Load train dataset 1
02/21/2022 03:09:07 - INFO - root -   Load train dataset 2
02/21/2022 03:09:13 - INFO - root -   Load train dataset 3
02/21/2022 03:09:20 - INFO - root -   Load train dataset 4
02/21/2022 03:09:29 - INFO - root -   Load train dataset 5
02/21/2022 03:09:36 - INFO - root -   Load train dataset 6
02/21/2022 03:09:43 - INFO - root -   Load train dataset 7
02/21/2022 03:09:49 - INFO - root -   Load train dataset 8
02/21/2022 03:09:55 - INFO - root -   Load train dataset 9


In [28]:
# Training loader: one shuffled query per batch, batched with the dataset's own collate function.
dataloader = DataLoader(
    dl,
    collate_fn=dl.collate_data,
    shuffle=True,
    batch_size=1,
    num_workers=1,
)

In [29]:
# Classification criterion over the model's 2-way output, plus Adam over all parameters.
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [31]:
def _score_candidates(model, tokenizer, sample, window_size, minibatch_size, max_length):
    """Return a 1-D float array with the model's relevance score for each of the
    first ``window_size`` candidates of one query."""
    encoded = tokenizer(
        sample["captions"],
        truncation=True,
        return_tensors="pt",
        max_length=max_length,
        padding='max_length')
    expert = sample["expert"]
    # The same caption is paired with every candidate: repeat its single-row encoding.
    input_ids = encoded["input_ids"].repeat(window_size, 1)
    attention_mask = encoded["attention_mask"].repeat(window_size, 1)

    # Zero-pad each (128, 145) expert map into the 145x145 single-channel ViT input.
    audio = torch.zeros((window_size, 1, 145, 145))
    audio[:, 0, :128, :] = expert[:window_size].reshape((window_size, 128, 145))

    scores = []
    # Cover the whole window, including a final partial chunk.
    for start in range(0, window_size, minibatch_size):
        stop = min(start + minibatch_size, window_size)
        output = model(input_ids[start:stop], attention_mask[start:stop], audio[start:stop])
        scores.extend(output[:, 1].tolist())   # column 1 = "relevant" score
    return np.asarray(scores, dtype=float)


def evaluate(model, test_dataloader, tokenizer, window_size, k=10, minibatch_size=100,
             max_length=150):
    """Compute top-k recall of ``model`` over ``test_dataloader``.

    For each query, the caption is paired with each of the first ``window_size``
    candidate expert features and every pair is scored (``output[:, 1]``). A query is a
    hit when its true candidate (``label_index``) is among the ``k`` highest scores.

    Args:
        model: module called as ``model(input_ids, attention_mask, audio)`` returning a
            (batch, 2) tensor. Its train/eval mode is restored on return.
        test_dataloader: iterable of batches where ``batch[0]`` is a dict with keys
            ``"expert"``, ``"captions"``, ``"label_index"`` and ``"label"``; its
            ``.dataset`` must support ``len()``.
        tokenizer: HuggingFace-style tokenizer callable.
        window_size: number of leading candidates scored per query. Queries whose label
            is None or whose true index falls outside the window count as misses.
        k: rank cutoff for a hit.
        minibatch_size: candidates scored per forward pass.
        max_length: tokenizer truncation/padding length.

    Returns:
        float: hits divided by ``len(test_dataloader.dataset)`` (0.0 for an empty dataset).
    """
    n_queries = len(test_dataloader.dataset)
    if n_queries == 0:
        return 0.0

    was_training = model.training
    model.eval()
    hits = 0
    try:
        with torch.no_grad():
            for inputs in test_dataloader:
                sample = inputs[0]
                index_true = sample["label_index"]
                # `or` short-circuits, so label_index is only compared when a label exists.
                if sample["label"] is None or index_true >= window_size:
                    continue
                scores = _score_candidates(
                    model, tokenizer, sample, window_size, minibatch_size, max_length)
                # Highest scores first (the previous ascending argsort ranked the *least* relevant).
                top_k = np.argsort(scores)[::-1][:k]
                hits += int(index_true in top_k)
    finally:
        model.train(was_training)
    return hits / n_queries
                    

    

In [34]:
# Train for EPOCHS epochs, reporting top-10 recall on evaluate_dataloader after each one.
window_size = 300      # leading candidates per query used for training
minibatch_size = 50    # candidates per optimizer step (plus the true candidate)

for epoch in range(1, EPOCHS + 1):
    # evaluate() leaves/sets eval mode, so re-enable training mode every epoch.
    model.train()
    for step, batch in enumerate(dataloader):
        sample = batch[0]
        labels = sample["label"]
        index_true = sample["label_index"]
        # Skip unlabeled queries and those whose true candidate lies outside the window.
        if labels is None or index_true >= window_size:
            continue

        expert = sample["expert"]
        labels = torch.tensor(labels, dtype=torch.long)
        data = tokenizer(
            sample["captions"],
            truncation=True,
            return_tensors="pt",
            max_length=MAX_LEN,
            padding='max_length')
        # Same caption for every candidate: repeat the single-row encoding.
        input_ids = data["input_ids"].repeat(expert.shape[0], 1)
        attention_mask = data["attention_mask"].repeat(expert.shape[0], 1)

        # Zero-pad each (128, 145) expert map into the 145x145 single-channel ViT input.
        torch_zeros = torch.zeros((window_size, 1, 145, 145))
        torch_zeros[:, 0, :128, :] = expert[:window_size].reshape((window_size, 128, 145))

        for minibatch in range(window_size // minibatch_size):
            range_index = list(range(minibatch_size * minibatch, minibatch_size * (minibatch + 1)))
            # Ensure each minibatch contains the positive exactly once.
            if index_true not in range_index:
                range_index.append(index_true)
            target = labels[range_index]

            output = model(
                input_ids[range_index],
                attention_mask[range_index],
                torch_zeros[range_index])

            optimizer.zero_grad()
            batch_loss = loss(output, target)
            batch_loss.backward()
            optimizer.step()

            print(f'Epoch: {epoch}, {step}/{len(dataloader)}, Loss:  {batch_loss.item()}')

    recall = evaluate(model, evaluate_dataloader, tokenizer, 100)
    print(f"Epoch:{epoch}, evaluate recall: {recall}")

KeyboardInterrupt: 