<a href="https://colab.research.google.com/github/datvodinh10/project-DD/blob/master/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup and imports

In [None]:
# `%pip` (unlike `!pip`) installs into the environment of the running kernel.
%pip install -q timm

In [None]:
# `%pip` (unlike `!pip`) installs into the environment of the running kernel.
%pip install -q gdown
# Download the dataset files from Google Drive into the current working directory.
# NOTE(review): presumably these are training_data.zip (unzipped below) and
# train_gt.txt (TARGET_PATH) — confirm and record which ID is which.
!gdown 1lR0b1QBIsXk9JqL9__HgJQi0PdWeNxKi
!gdown 1b4fCTrnfKnR0GHm1XCve9nhy7JyrqahR


In [None]:
# -o overwrites existing files without asking, so re-running this cell does not
# stall on unzip's interactive "replace ...?" prompt.
!unzip -qo /content/training_data.zip
# Kaggle equivalent:
# !unzip -qo /kaggle/working/training_data.zip

In [None]:
import torch

In [None]:
# Fetch the project code and cd into the clone so the `from src...` import in the
# next cell resolves against the repository's `src` package.
# NOTE(review): on a re-run git refuses to clone into the existing directory; the
# old clone is kept as-is (not updated).
!git clone https://github.com/datvodinh10/project-DD.git
%cd project-DD

In [None]:
from src.model.trainer import Trainer

In [None]:
# Step back out of the clone so the relative MODEL_PATH ("./") below refers to the
# directory the repo was cloned into, not to project-DD itself.
# NOTE(review): the later `from src.utils.inference import Inference` appears to rely
# on `src` having already been imported from inside the clone; after a kernel
# restart, run the notebook from the top rather than jumping to Predict.
%cd ..

In [None]:
# Data and output locations (Google Colab layout).
SRC_PATH, TARGET_PATH, MODEL_PATH = (
    "/content/new_train",     # folder of training images (Trainer's IMAGE_PATH)
    "/content/train_gt.txt",  # label / ground-truth file (Trainer's TARGET_PATH)
    "./",                     # folder the trained checkpoint is saved into
)

# Kaggle layout — swap these in when running on Kaggle:
# SRC_PATH = "/kaggle/working/new_train"
# TARGET_PATH = "/kaggle/working/train_gt.txt"
# MODEL_PATH = "./data"

## Train

In [None]:
import random

# Seed the RNGs up front so training and the sample pick in "Predict" are repeatable.
# torch.manual_seed seeds the CPU and all CUDA generators; it does not make
# non-deterministic cuDNN kernels deterministic.
# NOTE(review): if the project's dataloader/augmentations draw from numpy, seed it too.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the first CUDA GPU; fall back to the CPU when none is available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Hyper-parameters passed to Trainer (src.model.trainer, not shown here).
# Key names and nesting must stay exactly as written; Trainer reads them.
config = dict(
    transformer=dict(
        embed_size=384,          # model's hidden size
        num_heads=8,             # number of heads in multi-head self-attention (MSA)
        num_layers=4,            # number of encoder/decoder layers
        max_len=320,             # max sequence length
        dropout=0.2,             # dropout rate
        bias=False,              # attention bias
        embed_type='position',   # one of {'position', 'learned'}
    ),
    encoder=dict(
        # one of {'resnet18', 'resnet50', 'vgg', 'swin_transformer', 'vision_transformer'}
        type='swin_transformer',
        swin=dict(
            img_size=(64, 256),  # NOTE(review): presumably (height, width) — confirm
            embed_dim=48,
            window_size=8,
            in_channels=1,
        ),
    ),
    device=device,
    lr=1e-4,
    scheduler=dict(
        active=True,
        first_cycle_steps=400,
        cycle_mult=1,
        max_lr=2e-4,
        min_lr=5e-5,
        warmup_steps=50,
        gamma=0.9,
    ),
    dataloader=dict(
        num_workers=0,
    ),
    max_grad_norm=0.5,
    batch_size=256,
    num_epochs=100,
)

In [None]:
# Build the trainer from the config above plus the image folder and the label file.
trainer = Trainer(
    TARGET_PATH=TARGET_PATH,
    IMAGE_PATH=SRC_PATH,
    config=config,
)

In [None]:
def count_parameters(model: torch.nn.Module) -> int:
    """Return the number of trainable scalar weights in ``model``.

    Only parameters with ``requires_grad=True`` are counted; each contributes
    its element count (``numel()``). A model with no such parameters gives 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# count_parameters only counts tensors with requires_grad=True, so label the number
# as trainable parameters; `:,` adds thousands separators for readability.
print(f'Trainable parameters: {count_parameters(trainer.model):,}')

In [None]:
# Run training — the long-running cell of the notebook.
# NOTE(review): presumably iterates config['num_epochs'] epochs on config['device']; confirm in src.model.trainer.
trainer.train()

In [None]:
import os

# Bundle what the Predict section reloads: model weights, the training config and
# the vocabulary mappings.
# NOTE(review): config embeds a torch.device; confirm Inference's torch.load accepts it
# (PyTorch 2.6 changed torch.load's weights_only default).
save_dict = {
    'state_dict': trainer.model.state_dict(),
    'config': config,
    'vocab_size': trainer.vocabulary.vocab_size,
    'letter_to_idx': trainer.vocabulary.letter_to_idx,
    'idx_to_letter': trainer.vocabulary.idx_to_letter,
}

# MODEL_PATH may not exist yet (e.g. the Kaggle "./data" setting), so create it;
# os.path.join also avoids the ".//" that formatting "./" + "/" produced.
os.makedirs(MODEL_PATH, exist_ok=True)
file_path = os.path.join(
    MODEL_PATH, f"model_{config['encoder']['type']}_{config['num_epochs']}.pt"
)
torch.save(save_dict, file_path)

## Predict

In [None]:
from src.utils.inference import Inference
from PIL import Image
import matplotlib.pyplot as plt
import os
import torch

In [None]:
# Reload the checkpoint saved above so it can decode images.
infer = Inference(MODEL_PATH=file_path)

In [None]:
# Pick one training image at random and compare the decoding strategies on it.
# The listing is sorted (os.listdir order is arbitrary) and limited to regular files.
# The index is drawn from the real file count: the old hard-coded bound of 1000
# raised IndexError on smaller folders and never picked files past the 1000th.
image_files = sorted(
    name for name in os.listdir(SRC_PATH)
    if os.path.isfile(os.path.join(SRC_PATH, name))
)
if not image_files:
    raise FileNotFoundError(f"No image files found in {SRC_PATH!r}")
idx = torch.randint(len(image_files), ()).item()
img = Image.open(os.path.join(SRC_PATH, image_files[idx]))

# (sampling strategy, extra keyword arguments for infer.predict)
sampling_settings = [
    ('soft', {'temperature': 0.2}),
    ('repeat_penalty', {'temperature': 0.2}),
    ('hard', {}),
    ('top_k', {'k': 25}),
    ('top_p', {'p': 0.8}),
]
for sampling, kwargs in sampling_settings:
    print(f"{sampling:>14}: {infer.predict(img, sampling=sampling, **kwargs)}")

In [None]:
# cmap='gray' renders single-channel images as grayscale instead of matplotlib's
# default viridis false-colour; imshow ignores cmap for RGB(A) data, so it is safe
# either way.
fig, ax = plt.subplots()
ax.imshow(img, cmap='gray')
ax.set_title(os.path.basename(img.filename))
ax.axis('off')
plt.show()