In [None]:
# For Colab users: mount Google Drive and put the project directory on
# sys.path so the project's own modules (e.g. data_utils) can be imported.
# Outside Colab this cell is a no-op, so the notebook still runs top-to-bottom
# on a local kernel instead of failing with ModuleNotFoundError.
import os
import sys

# TODO: replace the placeholder with your project directory inside Drive.
COLAB_PROJECT_DIR = '/content/drive/{path to project directory}'

try:
    from google.colab import drive
except ImportError:  # not running inside Google Colab
    IN_COLAB = False
else:
    IN_COLAB = True

if IN_COLAB:
    drive.mount('/content/drive', force_remount=True)
    # Fail fast: a wrong / unreplaced path would otherwise only surface later
    # as a confusing import error for the project modules.
    if not os.path.isdir(COLAB_PROJECT_DIR):
        raise FileNotFoundError(
            f"Project directory not found: {COLAB_PROJECT_DIR!r}. "
            "Edit COLAB_PROJECT_DIR to point at your project in Drive."
        )
    sys.path.insert(0, COLAB_PROJECT_DIR)

In [1]:
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import math
import pandas as pd

from data_utils import MLDataset, collate_fn
# from modeling import Seq2SeqModel
from modeling_challenge import Seq2SeqModel

In [2]:
# This notebook requires a CUDA GPU. The checkpoint is later loaded with
# torch.load(model_path) without map_location, so a CPU fallback would most
# likely just fail later with a less helpful error. Fail fast here instead,
# with an explicit exception: a bare assert has no message and is stripped
# under `python -O`.
CUDA_DEVICE_INDEX = 0  # which GPU to use for training / inference

if not torch.cuda.is_available():
    raise RuntimeError(
        "A CUDA-capable GPU is required to run this notebook "
        "(in Colab: Runtime > Change runtime type > GPU)."
    )

torch.cuda.set_device(CUDA_DEVICE_INDEX)
device = torch.device(f"cuda:{CUDA_DEVICE_INDEX}")

In [3]:
# Hyperparameters passed to Seq2SeqModel(**kwargs); add or change keys and
# values to suit your model.
kwargs = dict(
    hidden_dim=256,          # hidden size of the encoder RNN
    nhead=4,                 # attention heads in the Transformer decoder
    dec_layers=4,            # Transformer decoder layers
    dim_feedforward=1024,    # feedforward width inside the Transformer
    dropout=0.2,             # Transformer dropout rate
    enc_layers=3,            # stacked RNN layers in the encoder
    rnn_dropout=0.3,         # dropout rate for the encoder RNN
    # NOTE(review): one more than kwargs_generate['max_length'] below —
    # presumably room for the start token; confirm against the model code.
    max_length=11,           # maximum sequence length
    cnn_settings=dict(       # configuration for CustomCNN
        block1_dim=32,
        block2_dim=64,
        block3_dim=128,
        fc_dim=256,
        model_type='VGG',    # CNN family: 'VGG' or 'ResNet'
    ),
)

# Extra keyword arguments forwarded to model.generate(...).
kwargs_generate = dict(max_length=10)

BATCH_SIZE = 128
# Output vocabulary size; the inference cell uses 27 (= NUM_CLASSES - 1) as the
# start token and 28 as the base when encoding predictions.
NUM_CLASSES = 28

In [6]:
# Build the seq2seq model from the hyperparameters above and move it to the
# selected device; swap in your own model class / checkpoint path if you like.
model = Seq2SeqModel(num_classes=NUM_CLASSES, **kwargs)
model = model.to(device)
print(model)

# Checkpoint read by the inference cell (expected to hold weights under "model").
model_path = './model.pt'

Seq2SeqModel(
  (encoder): Encoder(
    (cnn): CustomCNN(
      (block1): VGGBlock(
        (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (block2): VGGBlock(
        (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (block3): VGGBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (fc): Linear(in_features=1152, out_features=256, bias=True)
 

In [7]:
# Do not modify this cell!
# Runs the trained model over the test split and writes submission.csv with one
# row per test sample: a 1-based `id` and an integer-encoded `prediction`.

# shuffle=False keeps samples in dataset order, which the id computation below
# relies on. Labels come from a dummy file and the loop discards the targets;
# presumably the dummy labels still supply the per-sample `lengths` — confirm.
test_ds = MLDataset('data_final/imgs/test', 'data_final/labels/test_dummy.json')
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)

# The checkpoint is expected to be a dict holding the weights under "model".
# NOTE(review): no map_location is passed, so tensors are restored onto the
# device they were saved from — presumably a GPU, matching the CUDA check above.
state = torch.load(model_path)
model.load_state_dict(state["model"])
model.eval()  # inference mode (e.g. disables dropout)

ids = []
predictions = []
# collate_fn batches unpack as (images, targets, lengths); targets are unused here.
for batch_idx, (data, _, lengths) in enumerate(tqdm(test_dl)):       
    data = data.to(device) # (B, T, H, W, C)
    
    # start tokens should be located at the first position of the decoder input
    # (token id 27, i.e. NUM_CLASSES - 1 with the config above)
    start_tokens = (torch.ones([data.size(0), 1]) * 27).to(torch.long).to(device)
    with torch.no_grad():
        generated_tok = model.generate(data, lengths, start_tokens, **kwargs_generate) # (B, T)

    for i in range(generated_tok.size(0)):
        # 1-based global sample index. Assumes every batch before the last one is
        # full (DataLoader with the default drop_last=False and shuffle=False).
        sample_idx = batch_idx * BATCH_SIZE + i + 1
        ids.append(sample_idx)
        # Encode the first lengths[i] generated tokens as a base-28 integer,
        # least-significant digit first (token j contributes tok * 28**j).
        # math.pow works in floats; assuming token ids in [0, 27], the sum stays
        # exactly representable for up to 11 tokens (28**11 < 2**53).
        pred = 0
        for j, tok in enumerate(generated_tok[i][:lengths[i].int()].tolist()):
            pred += tok * math.pow(28, j)
        predictions.append(int(pred))

# One row per test sample; index=False keeps only the two named columns.
sub_df = pd.DataFrame({
    'id': ids,
    'prediction': predictions
})

sub_df.to_csv('submission.csv', index=False)
print("Created submission file successfully!")

  0%|          | 0/69 [00:00<?, ?it/s]

100%|██████████| 69/69 [00:07<00:00,  9.36it/s]

Created submission file successfully!



