In [1]:
# !pip install einops

In [2]:
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import pandas as pd
import numpy as np
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import json
import gc
import os
import copy
import time
import re
import torchvision
import torchvision.transforms as T
from sklearn.model_selection import train_test_split
from datasets import load_dataset 
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
# from einops import rearrange, repeat 



># **Test**

> # **data analysis**

In [3]:
# Load the processor published with the `microsoft/git-base-coco` checkpoint;
# it is handed to MyDataset below to encode each image/caption pair.
processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")

Downloading (…)rocessor_config.json:   0%|          | 0.00/503 [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/453 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [4]:
# Load the caption table. The preview shows a `name` column (image file name)
# and a `caption` column (target text); the two `Unnamed` columns look like
# index columns left over from earlier CSV round-trips — TODO confirm.
df = pd.read_csv('/kaggle/input/roco-brain/train/kaggle/working/traindata.csv')
df.head()

Unnamed: 0.1,Unnamed: 0,id,name,caption
0,5,ROCO_00008,PMC4805615_13244_2016_481_Fig12_HTML.jpg,A 3-year-old child with visual difficulties. ...
1,67,ROCO_00085,PMC3201077_AJNS-5-70-g001.jpg,Contrast MRI head axial section showing an ir...
2,86,ROCO_00105,PMC3272911_JCIS-1-43-g003.jpg,T1 axial image shows hypointense lesions in t...
3,87,ROCO_00106,PMC3174812_z9k0091109120002.jpg,An inflated representation of the right hemis...
4,134,ROCO_00164,PMC2939498_CRM2010-846534.003.jpg,MRI of the brain showing no mass or enhancing...


In [5]:
len(df)

2874

In [6]:
# Directory holding the image files named in df['name']. The trailing slash
# matters: MyDataset builds file paths by prefixing this string to the name.
path = '/kaggle/input/roco-brain/train/kaggle/working/train_images/'

In [7]:
# Keep only rows whose `name` appears in the directory listing of `path`,
# then display how many rows remain.
# NOTE(review): this overwrites `df` in place, so the unfiltered table is no
# longer available to later cells.
df = df[df['name'].isin(os.listdir(path))]
len(df)

2874

In [8]:
# max(df['caption'].apply(lambda x: len(processor(text=x)['input_ids'])))

> # **GITMODEL**

In [9]:
# Load the pretrained causal-LM weights for the same `microsoft/git-base-coco`
# checkpoint the processor was loaded from.
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

Downloading config.json:   0%|          | 0.00/2.82k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/707M [00:00<?, ?B/s]

Downloading generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

In [10]:
# Overwrite the pretrained weights with a saved state dict, loaded onto CPU.
# The file name matches what train() writes, so this is presumably a checkpoint
# from an earlier fine-tuning run — TODO confirm.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source (consider `weights_only=True` on
# PyTorch versions that support it).
model.load_state_dict(torch.load('/kaggle/input/roco-git/ROCO_Git.pth',map_location='cpu'))

<All keys matched successfully>

In [11]:
class MyDataset(Dataset):
    """Image/caption dataset that yields processor-encoded training examples.

    Each item is a dict of tensors produced by ``processor`` (with the leading
    batch dimension removed) plus a ``label`` tensor: a copy of ``input_ids``
    in which padded positions are set to -100, the default ``ignore_index`` of
    ``torch.nn.CrossEntropyLoss``.

    Args:
        dataframe: Table (or mapping) with a ``name`` column of image file
            names and a ``caption`` column of target texts.
        path: Directory that contains the image files.
        processor: Callable accepting ``images=`` and ``text=`` keyword
            arguments and returning a mapping of batched tensors that
            includes ``input_ids`` (and normally ``attention_mask``).
        max_length: Number of tokens every caption is padded/truncated to.
        return_image: When True, ``__getitem__`` returns ``(encoding, image)``
            instead of just ``encoding``.
    """

    def __init__(self, dataframe, path, processor, max_length, return_image: bool = False):
        super().__init__()
        self.data = {
            'image': list(dataframe['name']),
            'caption': list(dataframe['caption']),
        }
        self.path = path
        self.max_length = max_length
        self.processor = processor
        self.return_image = return_image

    def __len__(self):
        """Return the number of image/caption pairs."""
        return len(self.data['image'])

    def _load_image(self, idx):
        """Open the image for row ``idx`` and convert it to RGB."""
        # os.path.join gives the same result as plain concatenation when
        # ``path`` ends with a separator, and also works when it does not.
        return Image.open(os.path.join(self.path, self.data['image'][idx])).convert("RGB")

    def __getitem__(self, idx):
        """Return the encoded example at ``idx``.

        If the image cannot be read, the example at index 0 (image and
        caption) is returned instead so a single bad file does not abort an
        epoch.
        """
        try:
            image = self._load_image(idx)
        except (OSError, ValueError):
            # Only I/O and decoding failures are tolerated; a bare ``except``
            # would also swallow KeyboardInterrupt / SystemExit.
            print(f"error on read image {idx}")
            idx = 0
            image = self._load_image(idx)
        caption = self.data['caption'][idx]
        encoding = self.processor(
            images=image,
            text=caption,
            padding="max_length",
            # Without truncation a caption longer than max_length produces a
            # longer tensor and the default batch collation fails.
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop only the leading batch dimension; a bare squeeze() would also
        # remove any other dimension that happens to have size 1.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        label = encoding['input_ids'].clone()
        attention_mask = encoding.get('attention_mask')
        if attention_mask is not None:
            # Padding is exactly where the attention mask is zero, whatever
            # the tokenizer's pad token id is.
            pad_positions = attention_mask == 0
        else:
            # No mask provided: fall back to treating token id 0 as padding.
            pad_positions = label == 0
        label[pad_positions] = -100
        encoding['label'] = label
        if self.return_image:
            return encoding, image
        return encoding

In [12]:
# Build the training dataset; 700 is the max_length every caption is padded to.
trainset = MyDataset(df,path,processor,700)
# Debug aid (disabled): print the shape of every tensor in the first example.
# for k, v in trainset[0].items():
#     print(k,v.shape)

In [13]:
# Batches of 16 examples, reshuffled every epoch; drop_last=True discards the
# final incomplete batch so every step sees a full batch.
train_loader = DataLoader(trainset,batch_size=16,drop_last=True,shuffle=True)

In [14]:
def train(model, train_loader, epochs=100, lr=5e-5, save_path='ROCO_Git.pth', time_limit=3600*12-30):
    """Fine-tune ``model`` on ``train_loader`` and checkpoint after every epoch.

    Runs on CUDA when available (wrapping the model in ``nn.DataParallel``
    when more than one GPU is visible), otherwise on CPU. Training stops early
    when the average epoch duration suggests another epoch would exceed
    ``time_limit``.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...,
            pixel_values=..., labels=...)`` whose output has a ``loss``
            attribute.
        train_loader: Iterable of batches; each batch is a mapping with the
            keys ``input_ids``, ``pixel_values``, ``attention_mask`` and
            ``label``.
        epochs: Maximum number of passes over ``train_loader``.
        lr: AdamW learning rate.
        save_path: File the model's ``state_dict`` is written to after each
            epoch (overwritten every time).
        time_limit: Wall-clock budget in seconds.
    """
    # Keep a handle on the unwrapped module: DataParallel shares its
    # parameters, so this is what gets optimised and checkpointed whether or
    # not the wrapper is used (the original `model.module` failed on CPU).
    base_model = model
    if torch.cuda.is_available():
        device = 'cuda'
        model.to(device)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            # Use whatever GPUs exist instead of assuming exactly two.
            model = nn.DataParallel(model, device_ids=list(range(gpu_count)))
    else:
        device = 'cpu'
    # Build the optimizer after the parameters are on their final device.
    optimizer = torch.optim.AdamW(base_model.parameters(), lr=lr)
    model.train()
    start_time = time.time()
    for epoch in range(epochs):
        for idx, batch in enumerate(train_loader):
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            pixel_values = batch["pixel_values"].to(device)
            att_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)
            outputs = model(input_ids=input_ids,
                            attention_mask=att_mask,
                            pixel_values=pixel_values,
                            labels=label)

            # DataParallel gathers one loss per replica; mean() reduces them
            # to a scalar and is a no-op for an already-scalar loss.
            loss = outputs.loss.mean()
            print(f"epoch: {epoch} batch: {idx} -----------loss:{loss.item()}")
            loss.backward()
            optimizer.step()
            # Release per-batch tensors eagerly to keep peak memory down.
            del input_ids, pixel_values, outputs, loss, att_mask, label
            if device == 'cuda':
                torch.cuda.empty_cache()
            gc.collect()
        torch.save(base_model.state_dict(), save_path)
        period = time.time() - start_time
        speed = period/(epoch+1)
        print(f"1 epoch speed {speed}")
        # Stop if one more average-length epoch would overrun the budget.
        if speed + period > time_limit:
            break

In [15]:
# Start fine-tuning with the defaults (up to 100 epochs; train() also stops
# early based on its internal wall-clock limit).
train(model,train_loader)