In [None]:
import torch
import time
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import ViltProcessor, ViltForImagesAndTextClassification, ViltConfig, ViltModel, AdamW
import requests
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import ast
import shutil
import torch.nn.functional as F


In [None]:
SEED = 97  # single source of truth for the global RNG seeds below
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fb710234bb0>

In [None]:
def scoring(pred, target, topk):
    """Per-sample precision@k, recall@k and F1@k for multi-label predictions.

    Args:
        pred: score tensor of shape [batch_size, hashtag_vocab_size]; higher = more likely.
            May require grad (training logits); it is detached before ranking.
        target: tensor (or nested sequence) of shape [batch_size, hashtag_vocab_size];
            every non-zero entry marks that column index as a true tag.
        topk: number of highest-scoring indices treated as the predicted tag set.

    Returns:
        Three lists ``(precision, recall, f1)`` with one value per sample. Samples with
        no true tags score 0 on all three metrics.
    """
    # Column indices ordered from highest to lowest score.
    ranked = torch.argsort(pred.detach(), dim=1, descending=True).cpu().numpy()  # [batch_size, hashtag_vocab_size]
    # Move the targets to the CPU once instead of comparing (possibly GPU) scalars one by one.
    target = torch.as_tensor(target).detach().cpu()  # [batch_size, hashtag_vocab_size]

    precision = []
    recall = []
    f1 = []
    for i in range(len(ranked)):
        true_tags = set(torch.nonzero(target[i], as_tuple=True)[0].tolist())
        this_precision = 0
        this_recall = 0
        this_f1 = 0
        if true_tags:
            # ranked[i] is a permutation, so `hits` counts distinct correct predictions and
            # is the numerator of both precision and recall. Slicing (rather than indexing
            # 0..topk-1) also tolerates topk larger than the vocabulary.
            hits = len(true_tags.intersection(ranked[i][:topk].tolist()))
            this_precision = hits / topk
            this_recall = hits / len(true_tags)
            if this_precision != 0 and this_recall != 0:
                this_f1 = 2 * (this_precision * this_recall) / (this_precision + this_recall)
        precision.append(this_precision)
        recall.append(this_recall)
        f1.append(this_f1)
    return precision, recall, f1

In [None]:

# Prefer the first GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [None]:
class ClassificationModel(nn.Module):
    """ViLT encoder followed by a small MLP head that scores every hashtag class.

    The pooled ViLT output goes through
    Linear(hidden_size -> hidden_dim) -> LayerNorm -> GELU -> Linear(hidden_dim -> num_labels),
    and the raw (un-normalized) logits are returned.

    Args:
        pretrained_model: checkpoint name/path passed to ``ViltModel.from_pretrained``.
        hidden_dim: width of the intermediate head layer (default 1000, as before).
        num_labels: number of output classes (default 2000, as before).
    """

    def __init__(self, pretrained_model='dandelin/vilt-b32-mlm', hidden_dim=1000, num_labels=2000):
        super(ClassificationModel, self).__init__()
        self.vilt = ViltModel.from_pretrained(pretrained_model)
        # Take the encoder width from its config instead of hard-coding 768, so checkpoints
        # with a different hidden size also work. Submodule names are unchanged so existing
        # state_dicts still load.
        self.linear = nn.Linear(self.vilt.config.hidden_size, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.acti = nn.GELU()
        self.linear2 = nn.Linear(hidden_dim, num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        pixel_values=None,
        pixel_mask=None,
        head_mask=None,
        inputs_embeds=None,
        image_embeds=None,
        image_token_type_idx=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Return class logits for a batch of text + image inputs.

        ``return_dict`` is accepted for signature compatibility but ignored: the encoder is
        always run with ``return_dict=True`` because ``pooler_output`` is read below.
        """
        # Pass everything by keyword. ViltModel.forward's positional order is
        # (input_ids, attention_mask, token_type_ids, ...); the previous positional call
        # passed token_type_ids where attention_mask belongs and vice versa.
        outputs = self.vilt(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            image_token_type_idx=image_token_type_idx,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        predict = self.linear(outputs.pooler_output)
        predict = self.norm(predict)
        predict = self.acti(predict)
        predict = self.linear2(predict)
        return predict

In [None]:
# Build the ViLT classifier and place it on the selected device (nn.Module.to returns the module).
model = ClassificationModel().to(device)
# Processor for the same checkpoint; collate_fn uses it to pad images and build pixel masks.
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

Some weights of the model checkpoint at dandelin/vilt-b32-mlm were not used when initializing ViltModel: ['mlm_score.bias', 'mlm_score.transform.LayerNorm.bias', 'mlm_score.transform.dense.bias', 'mlm_score.decoder.weight', 'mlm_score.transform.LayerNorm.weight', 'mlm_score.transform.dense.weight']
- This IS expected if you are initializing ViltModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViltModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
class SSTDataset(Dataset):
    """Image + text posts with hashtag targets, indexed by a CSV file.

    Row ``idx`` refers to the image ``<root_dir>/<row['Unnamed: 0']>.jpg``. Each item is
    the ViLT processor output with the leading batch dimension removed, plus the parsed
    ``label`` / ``labels`` targets.

    Args:
        csv_file: path to the CSV index.
        root_dir: directory that holds the ``.jpg`` images.
    """

    def __init__(self, csv_file, root_dir):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]  # look the row up once instead of once per column
        img_name = os.path.join(self.root_dir, str(row['Unnamed: 0']) + '.jpg')
        # Use a context manager so the file handle is released after decoding. Convert to
        # 3-channel RGB so grayscale / palette / RGBA files reach the processor in the
        # channel layout it normalizes.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        # Text with location information; use row['after_contents'] for text without it.
        text = str(row['concat_category_location'])
        # Targets are stored in the CSV as Python-literal lists.
        label = ast.literal_eval(row['new_hashtags_2000_onehot'])
        labels = ast.literal_eval(row['new_hashtags_2000_onehots'])
        process_output = self.processor(image, text, truncation=True, padding='max_length', return_tensors="pt")
        # Remove only the leading batch dimension added by return_tensors="pt".
        for k, v in process_output.items():
            process_output[k] = v.squeeze(0)
        process_output['labels'] = labels
        process_output['label'] = label

        return process_output

In [None]:
def collate_fn(batch):
    """Merge a list of ``SSTDataset`` items into a single dict of batched tensors.

    Text tensors are stacked as-is, since the dataset pads text to max_length. Image
    tensors are padded to a common size through the global ViLT ``processor``, which
    also produces the matching ``pixel_mask``. Both target lists become int64 tensors.
    """
    # Pad pixel values to a shared size and get the mask of real (non-padding) pixels.
    padded = processor.feature_extractor.pad_and_create_pixel_mask(
        [item['pixel_values'] for item in batch], return_tensors="pt"
    )

    collated = {
        key: torch.stack([item[key] for item in batch])
        for key in ('input_ids', 'attention_mask', 'token_type_ids')
    }
    collated['pixel_values'] = padded['pixel_values']
    collated['pixel_mask'] = padded['pixel_mask']
    collated['labels'] = torch.tensor([item['labels'] for item in batch], dtype=torch.long)
    collated['label'] = torch.tensor([item['label'] for item in batch], dtype=torch.long)
    return collated

In [None]:

batch_size = 128
dataset = SSTDataset('post_dataset.csv', './drop_image')

# 80 / 10 / 10 split into train / validation / test. The split relies on the global
# torch seed set at the top of the notebook.
train_size = int(0.8 * len(dataset))
test_val_size = len(dataset) - train_size
train_dataset, test_val_dataset = torch.utils.data.random_split(dataset, [train_size, test_val_size])
val_size = int(0.5 * len(test_val_dataset))
test_size = len(test_val_dataset) - val_size
val_dataset, test_dataset = torch.utils.data.random_split(test_val_dataset, [val_size, test_size])

# Only training batches need shuffling. Evaluation metrics are averaged over all samples,
# so shuffling val/test adds nothing and would draw from the global torch RNG every pass,
# perturbing the next epoch's training order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)


In [None]:
def save_checkpoint(state, is_best, model_save_path, filename):
    """Write ``state`` to ``filename`` with ``torch.save``.

    If ``is_best`` is true, the freshly written file is also copied to
    ``<model_save_path>/model_best.pth.tar``, so the best checkpoint so far is always
    available under a fixed name.
    """
    torch.save(state, filename)
    if not is_best:
        return
    best_path = os.path.join(model_save_path, 'model_best.pth.tar')
    shutil.copyfile(filename, best_path)

# Every run gets its own timestamped directory under ./saved_model/ for the TensorBoard
# logs, the per-epoch checkpoints and the config dump.
run_timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
save_path = os.path.join('./saved_model/', run_timestamp)
writer = SummaryWriter(log_dir=save_path)

In [None]:
# Cross-entropy over the class logits. The training loop passes float targets
# (label.float()), which CrossEntropyLoss treats as class probabilities.
# NOTE(review): if a `label` row can contain more than one 1, its targets do not sum to 1;
# confirm whether BCEWithLogitsLoss was intended for the multi-label case.
criterion = nn.CrossEntropyLoss()

In [None]:
# Hyperparameters

lr = 0.0001
weight_decay = 0.01
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps=1e-6 keeps the
# transformers default that was in effect before (torch's own default is 1e-8).
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-6)
threshold = 5  # early-stopping patience: stop once val F1 has not improved for more than this many epochs
topk = 5       # number of top-scoring hashtags used for precision/recall/F1
# Multiply the learning rate by 0.95 per scheduler step (the training loop steps once per epoch).
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)



In [None]:
# Record this run's settings next to its checkpoints and TensorBoard logs.
config_lines = [
    f"batch_size = {batch_size}\n",
    f"learning_rate = {lr}\n",
    f"threshold = {threshold}\n",
    f"weight_decay = {weight_decay}\n",
    f"optim = {optimizer}\n",
    f"topk = {topk}\n",
    "Location = True \n",
    "hashtag = 2000 \n",
    f"{criterion}\n",
    "Model 1",
]
with open(os.path.join(save_path, "config.txt"), "w") as f:
    f.writelines(config_lines)

In [None]:
def _mean(values):
    """Arithmetic mean of a non-empty list (the sum/len used for every metric below)."""
    return sum(values) / len(values)


def _log_metrics(prefix, step, precision, recall, f1, loss):
    """Write ``<prefix>_precision``, ``_recall``, ``_f1`` and ``_loss`` to TensorBoard at ``step``."""
    writer.add_scalar(tag=prefix + '_precision', scalar_value=precision, global_step=step)
    writer.add_scalar(tag=prefix + '_recall', scalar_value=recall, global_step=step)
    writer.add_scalar(tag=prefix + '_f1', scalar_value=f1, global_step=step)
    writer.add_scalar(tag=prefix + '_loss', scalar_value=loss, global_step=step)


# Train epoch by epoch, validating and checkpointing after each one, until validation F1
# has not improved for more than `threshold` consecutive epochs.
global_steps = 0
epoch = 0
max_f1 = 0
stop_cnt = 0
while True:
    epoch += 1

    # ---- training ----
    model.train()
    precision = []
    recall = []
    f1 = []
    train_loss = 0
    cnt = 0
    for batch in tqdm(train_loader, total=len(train_loader)):
        batch = {k: v.to(device) for k, v in batch.items()}  # already on `device` afterwards
        logits = model(batch['input_ids'], batch['attention_mask'], batch['token_type_ids'],
                       batch['pixel_values'], batch['pixel_mask'])
        label = batch['label']    # target for the loss
        labels = batch['labels']  # target for scoring()
        loss = criterion(logits, label.float())
        # criterion returns a batch mean; weight it by batch size so dividing by the sample
        # count `cnt` yields a true per-sample mean (previously it was divided by batch size twice).
        train_loss += loss.item() * len(label)
        cnt += len(label)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        batch_p, batch_r, batch_f1 = scoring(logits, labels, topk)
        precision.extend(batch_p)
        recall.extend(batch_r)
        f1.extend(batch_f1)
        _log_metrics('batch', global_steps, _mean(batch_p), _mean(batch_r), _mean(batch_f1), loss.item())
        global_steps += 1

    _log_metrics('train', epoch, _mean(precision), _mean(recall), _mean(f1), train_loss / cnt)
    scheduler.step()

    # ---- validation ----
    model.eval()
    precision = []
    recall = []
    f1 = []
    val_loss = 0
    cnt = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, total=len(val_loader)):
            batch = {k: v.to(device) for k, v in batch.items()}
            label = batch['label']
            labels = batch['labels']
            logits = model(batch['input_ids'], batch['attention_mask'], batch['token_type_ids'],
                           batch['pixel_values'], batch['pixel_mask'])
            loss = criterion(logits, label.float())
            val_loss += loss.item() * len(label)  # per-sample weighting, as in training
            cnt += len(label)
            batch_p, batch_r, batch_f1 = scoring(logits, labels, topk)
            precision.extend(batch_p)
            recall.extend(batch_r)
            f1.extend(batch_f1)
    val_p = _mean(precision)
    val_r = _mean(recall)
    val_f1 = _mean(f1)
    _log_metrics('val', epoch, val_p, val_r, val_f1, val_loss / cnt)

    # Early-stopping bookkeeping: reset patience on a new best validation F1.
    if val_f1 > max_f1:
        max_f1 = val_f1
        stop_cnt = 0
        is_best = True
    else:
        stop_cnt += 1
        is_best = False

    save_checkpoint({
        'epoch': epoch,
        'model': model,
        'state_dict': model.state_dict(),
        'precision': val_p,
        'recall': val_r,
        'f1-score': val_f1,
        'optimizer': optimizer.state_dict()
    }, is_best, save_path, os.path.join(save_path, 'epoch' + str(epoch) + '.pth.tar'))

    if stop_cnt > threshold:  # patience exhausted
        print("Training finished.")
        break

writer.flush()  # make sure the last epoch's scalars reach disk