<a href="https://colab.research.google.com/github/datvodinh10/project-DD/blob/master/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup: download data, clone the project, import the trainer

In [None]:
!pip install gdown
!gdown 1dQ7dqVnBfp4STYMVsnkiLfiAIdFwrSkd # Training data
!gdown 1YedVnk4uKFBPInsa6Mzik0bmZK2Vuf4e # Target label data
!git clone https://github.com/datvodinh10/project-DD.git
%cd project-DD
from src.model.trainer import Trainer
%cd ..

In [None]:
import torch

## Train

In [None]:
# Prefer the first CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Hyper-parameters handed to Trainer; key names are read by the project code, so keep them verbatim.
config = dict(
    transformer=dict(
        embed_size=384,          # model's hidden size
        num_heads=8,             # number of attention heads
        num_layers=4,            # number of encoder/decoder layers
        max_len=320,             # max sequence length
        dropout=0.1,             # dropout rate
        bias=False,              # attention bias
        embed_type='position',   # {'position', 'learned'}
    ),
    encoder=dict(
        type='swin_transformer_v2',  # {'swin_transformer', 'swin_transformer_v2', 'resnet18', 'resnet50', 'vgg'}
        swin=dict(
            img_size=(64, 128),
            embed_dim=48,
            window_size=8,
            in_channels=3,
            dropout=0.1,
        ),
    ),
    device=device,
    lr=1e-4,
    # NOTE(review): with the scheduler active, 'lr' may be overridden by max_lr/min_lr — confirm in Trainer.
    scheduler=dict(
        active=True,
        first_cycle_steps=400,
        cycle_mult=2,
        max_lr=3e-4,
        min_lr=3e-5,
        warmup_steps=50,
        gamma=0.9,
    ),
    dataloader=dict(
        num_workers=0,
    ),
    max_grad_norm=0.5,
    batch_size=256,
    num_epochs=150,
    save_per_epochs=5,
    print_type='per_epoch',  # {'per_epoch', 'per_batch'}
)

In [None]:
# !unzip -q /content/training_data.zip
# SRC_PATH = "/content/new_train"
# TARGET_PATH = "/content/train_gt.txt"
# MODEL_PATH = "/content"

!unzip -q /kaggle/working/training_data.zip
SRC_PATH = "/kaggle/working/new_train"
TARGET_PATH = "/kaggle/working/train_gt.txt"
MODEL_PATH = "/kaggle/working"

In [None]:
# Build the project Trainer from the config dict and the data/model locations chosen above.
trainer = Trainer(
    config=config,
    IMAGE_PATH=SRC_PATH,
    TARGET_PATH=TARGET_PATH,
    MODEL_PATH=MODEL_PATH,
)

In [None]:
def count_parameters(model: torch.nn.Module, trainable_only: bool = True) -> int:
    """Count the scalar parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: If True (the default, matching the original behaviour), count only
            parameters with ``requires_grad=True``; if False, count every parameter.

    Returns:
        Total number of elements across the selected parameter tensors (0 for a module
        with no parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

# count_parameters keeps only tensors with requires_grad=True, so report them as trainable, not total.
print(f'Trainable parameters: {count_parameters(trainer.model):,}')

In [None]:
# Launch training with the Trainer built above.
# NOTE(review): presumably runs config['num_epochs'] epochs and writes checkpoints under MODEL_PATH
# every config['save_per_epochs'] epochs — confirm in src/model/trainer.py.
trainer.train()