In [82]:
import glob
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, random_split,DataLoader
from transformers import  TrainingArguments, Trainer, ViTFeatureExtractor, BertTokenizer, VisionEncoderDecoderModel
import torch
import gc
import os
torch.manual_seed(42)
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
%%capture
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained("google/vit-base-patch16-224-in21k", "bert-base-uncased").to(device)
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

In [11]:
Paths = '/media/delta/S/Signs/*/'

In [12]:
list_of_sign_types = glob.glob(Paths)

In [14]:
list_of_sign_types = [Path(l).name for l in list_of_sign_types]

In [107]:
for l in list_of_sign_types:
    count = len(glob.glob(f'/media/delta/S/Signs/{l}/*'))
    if count > 10 and l != 'Unidentified':
        print(count,'photos')

30 photos
33 photos
13 photos
36 photos
116 photos
34 photos
13 photos
13 photos
26 photos
13 photos
31 photos
28 photos
30 photos
30 photos
18 photos
24 photos
12 photos
82 photos
50 photos
51 photos
114 photos
16 photos
20 photos
13 photos
18 photos
11 photos
14 photos
13 photos
12 photos
18 photos
29 photos


In [55]:
SignName = []
SignPath = []
for l in list_of_sign_types:
    count = len(glob.glob(f'/media/delta/S/Signs/{l}/*'))
    if count > 10  and l != 'Unidentified':
        sign_name = l
        for p in (glob.glob(f'/media/delta/S/Signs/{l}/*')):
            SignPath.append(p)
            SignName.append(sign_name)

In [56]:
data = {'CAPTION': SignName, 'IMAGEPATH': SignPath}

In [57]:
df = pd.DataFrame(data)

In [59]:
print('number of avaialable training data:', len(df),  'photos')

number of avaialable training data: 961 photos


In [94]:
class CustomDataset(Dataset):
    def __init__(self,ds, tokenizer,feature_extractor):
        self.Pixel_Values = []
        self.Labels = []
        for i,r in ds.iterrows():
            try:
                image_path = r['IMAGEPATH']             #A table in csv format with 2 columns IMAGEPATH and CAPTION
                labels = r['CAPTION']
                labels = str(labels)
                image = Image.open(str(image_path)).convert("RGB")
                pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
                self.Pixel_Values.append(pixel_values)
                labels = tokenizer(labels,return_tensors="pt", truncation=True, max_length=128, padding="max_length").input_ids
                labels[labels == tokenizer.pad_token_id] = -100
                self.Labels.append(labels)
            except:
                print(labels)
        
    def __len__(self):
        return len(self.Pixel_Values)

    def __getitem__(self, idx):
        return {"pixel_values": self.Pixel_Values[idx], "labels": self.Labels[idx]}

In [95]:
dataset = CustomDataset(df,tokenizer,feature_extractor)

D4-2-2-Q01A
D4-2-2-Q01A
D4-3A
D4-6B
D4-6B
G1-1
G6-1
R1-2B
G2-1
R4-1C
R4-1C
R5-35 (D)
W8-2B
W8-2B


In [96]:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"][0] for example in examples])   #0 to change from [1,3,224,224] to  [3,224,224]  torch stack will add it back depends on the batch size,
    labels = torch.stack([example["labels"][0] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

In [97]:
dataloader = DataLoader(dataset, batch_size=12, shuffle=True,num_workers=12,collate_fn=collate_fn)

In [104]:
%%capture
optimizer = torch.optim.AdamW(model.parameters(), lr=1)
scaler = torch.cuda.amp.GradScaler()
writer = SummaryWriter()
model.train()

In [105]:
step = 0
total_loss = 0
for epoch in range(100):
    for i, b in enumerate(dataloader):
        pixel_values = b['pixel_values'].to(device)
        labels = b['labels'].to(device)
        with torch.cuda.amp.autocast():
            loss = model(labels=labels, pixel_values=pixel_values).loss
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        if step % 100 == 0:
            print(total_loss/100,step)
            total_loss = 0
        step += 1

0.11997478485107421 0
11.834055395126343 100
11.883658781051636 200
11.91219877243042 300
11.881713981628417 400
11.868750429153442 500
11.881986074447632 600
11.92239993095398 700
11.876637926101685 800
11.872782335281372 900


KeyboardInterrupt: 

In [106]:
# Failed.... Pretrained model wasnt trained on non-human language D4-4xssdfasdfasdf

In [None]:
# Technically, everyone absolutely lovesss my igenious design. I rest my case.