In [58]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import models


import numpy as np
import wandb
import pandas as pd
import cv2
from tqdm import tqdm
import argparse
from PIL import Image

from sklearn.metrics import f1_score, precision_score, accuracy_score
from sklearn.model_selection import train_test_split
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input directory and print the full path of every file in it.
input_paths = (
    os.path.join(root, name)
    for root, _, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/digit-recognizer/sample_submission.csv
/kaggle/input/digit-recognizer/train.csv
/kaggle/input/digit-recognizer/test.csv


In [59]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [60]:
class Model_VGG(nn.Module):
    """Pretrained torchvision VGG whose last classifier layer is replaced by a 10-way head.

    Args:
        model_type: ``"vgg_16"`` or ``"vgg_19"``; selects which backbone to build.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """

    SUPPORTED_TYPES = ("vgg_16", "vgg_19")

    def __init__(self, model_type):
        super(Model_VGG, self).__init__()

        # Fail fast with a clear message; previously an unknown name fell through
        # both branches and surfaced later as an AttributeError on model_backbone.
        if model_type not in self.SUPPORTED_TYPES:
            raise ValueError(
                "Unsupported model_type {!r}; expected one of {}".format(
                    model_type, self.SUPPORTED_TYPES
                )
            )

        build_backbone = models.vgg16 if model_type == "vgg_16" else models.vgg19
        self.model_backbone = build_backbone(pretrained=True)

        # Swap the final fully-connected layer for a 10-class (digits 0-9) output,
        # reading the input width from the layer being replaced.
        in_features = self.model_backbone.classifier[6].in_features
        self.model_backbone.classifier[6] = nn.Linear(in_features, 10)

    def forward(self, x):
        """Return the backbone's raw class scores (logits) for the batch ``x``."""
        return self.model_backbone(x)

In [61]:
class Digit_Dataset(Dataset):
    """Dataset over a digit DataFrame: column 0 is the label, the rest are 784 pixels.

    Args:
        df_data: DataFrame whose first column holds the label and whose
            remaining 784 columns hold one flattened 28x28 image per row.
        transforms: optional callable applied to each ``(28, 28, 3)`` uint8
            image array before it is returned.
    """

    def __init__(self, df_data, transforms=None):
        self.data = df_data
        self.transforms = transforms
        # Convert once up front instead of slicing the DataFrame with ``iloc``
        # on every __getitem__ call, which dominates per-sample loading time.
        self._labels = df_data.iloc[:, 0].to_numpy()
        self._images = (
            df_data.iloc[:, 1:]
            .to_numpy()
            .reshape(len(df_data), 28, 28)
            .astype(np.uint8)
        )

    def __len__(self):
        """Number of rows (samples) in the underlying DataFrame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The grayscale image is replicated along a trailing axis to 3 channels,
        then passed through ``self.transforms`` when one was supplied.
        """
        # 3 channels
        img = np.stack((self._images[idx],) * 3, axis=-1)
        label = self._labels[idx]
        if self.transforms:
            img = self.transforms(img)
        return img, label


In [62]:
class Training_Model:
    """Run per-epoch training and validation passes for a classifier, logging to wandb.

    Each call to ``train`` or ``validation`` makes one full pass over the
    corresponding loader, prints the mean batch loss, macro precision, accuracy
    and macro F1 (all rounded to 4 decimals) and sends the same values to
    ``wandb.log``. ``validation`` also saves the model's ``state_dict`` when the
    validation F1 exceeds both the best value seen so far and ``MIN_F1_TO_SAVE``.

    Args:
        model: ``nn.Module`` to train; it is moved to ``device``.
        optimizer: optimizer already bound to ``model``'s parameters.
        criterion: loss callable invoked as ``criterion(output, target)``.
        dict_dataloader: mapping with ``"train_loader"`` and ``"val_loader"``
            entries, each an iterable of ``(data, target)`` batches.
        model_name: prefix used in checkpoint file names.
        save_dir: directory checkpoints are written to.
        device: device to run on; ``None`` selects ``cuda:0`` when CUDA is
            available and the CPU otherwise.
    """

    # Validation F1 must exceed this before any checkpoint is written.
    MIN_F1_TO_SAVE = 0.9

    def __init__(
        self,
        model,
        optimizer,
        criterion,
        dict_dataloader,
        model_name="vgg_16",
        save_dir="/kaggle/working",
        device=None,
    ):
        # Resolve the device here instead of relying on a notebook-level global.
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_loader = dict_dataloader["train_loader"]
        self.val_loader = dict_dataloader["val_loader"]
        self.model_name = model_name
        self.save_dir = save_dir
        self.best_f1_score = 0

    @staticmethod
    def _summarize(total_loss, num_batches, targets, preds):
        """Return (mean loss, macro precision, accuracy, macro F1), each rounded to 4 decimals."""
        return (
            round(total_loss / num_batches, 4),
            round(precision_score(targets, preds, average="macro"), 4),
            round(accuracy_score(targets, preds), 4),
            round(f1_score(targets, preds, average="macro"), 4),
        )

    def _checkpoint_path(self, f1_value):
        """Build the checkpoint file path for a given validation F1 score."""
        return os.path.join(
            self.save_dir, "{}_f1_{}.pt".format(self.model_name, f1_value)
        )

    def train(self, epochs):
        """Run one optimisation pass over the training loader.

        Args:
            epochs: index of the current epoch; only used in the printed banner.
        """
        print("============TRAINING START {}============".format(epochs))
        self.model.train()
        total_loss = 0
        preds, targets = [], []
        for data, target in tqdm(self.train_loader):
            data = data.to(self.device)
            target = target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            # Predicted class = index of the largest score along the class axis.
            preds += output.detach().argmax(dim=1).tolist()
            targets += target.tolist()

        total_loss, precision_scr, accuracy_scr, train_f1_scr = self._summarize(
            total_loss, len(self.train_loader), targets, preds
        )
        print(
            "train_loss: {}, train_precision: {}, train_accuracy: {}, train_f1_score: {}".format(
                total_loss, precision_scr, accuracy_scr, train_f1_scr
            )
        )
        wandb.log(
            {
                "train_loss": total_loss,
                "train_precision_score": precision_scr,
                "train_accuracy_score": accuracy_scr,
                "train_f1_score": train_f1_scr,
            }
        )

    def validation(self, epochs):
        """Evaluate on the validation loader and checkpoint on a new best F1.

        Args:
            epochs: index of the current epoch; only used in the printed banner.
        """
        print("============VAL START {}============".format(epochs))
        self.model.eval()
        total_loss = 0
        preds, targets = [], []
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for data, target in tqdm(self.val_loader):
                data = data.to(self.device)
                target = target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                total_loss += loss.item()
                preds += output.argmax(dim=1).tolist()
                targets += target.tolist()

        total_loss, precision_scr, accuracy_scr, val_f1_scr = self._summarize(
            total_loss, len(self.val_loader), targets, preds
        )
        print(
            "val_loss: {}, val_precision: {}, val_accuracy: {}, val_f1_score: {}".format(
                total_loss, precision_scr, accuracy_scr, val_f1_scr
            )
        )
        wandb.log(
            {
                "val_loss": total_loss,
                "val_precision_score": precision_scr,
                "val_accuracy_score": accuracy_scr,
                "val_f1_score": val_f1_scr,
            }
        )
        # Only keep checkpoints that improve on the best F1 and clear the quality floor.
        if val_f1_scr > self.best_f1_score and val_f1_scr > self.MIN_F1_TO_SAVE:
            self.best_f1_score = val_f1_scr
            torch.save(self.model.state_dict(), self._checkpoint_path(val_f1_scr))

In [63]:
def get_loader(csv_file, transforms, batch_size=64, shuffle=True):
    """Wrap a digit DataFrame in a ``Digit_Dataset`` and return a ``DataLoader`` over it.

    Args:
        csv_file: despite the name, the already-loaded DataFrame handed to
            ``Digit_Dataset`` (label in the first column, pixels after it).
        transforms: callable applied to every image by the dataset.
        batch_size: number of samples per batch.
        shuffle: whether to reshuffle every epoch. Defaults to ``True`` to
            match the previous behaviour; pass ``False`` for evaluation data.

    Returns:
        A ``DataLoader`` yielding ``(image_batch, label_batch)`` pairs.
    """
    dataset = Digit_Dataset(csv_file, transforms)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def get_input(argv=None):
    """Parse the training hyper-parameters from command-line style arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) reads
            ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``model_type``, ``batch_size``, ``epochs``,
        ``lr``, ``seed`` and ``weight_decay``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default="vgg_16")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--weight_decay", type=float, default=0.0001)
    # parse_known_args ignores arguments this parser does not define (e.g. the
    # ``-f <connection-file>`` a Jupyter kernel is launched with), where
    # parse_args would abort with SystemExit.
    args, _unknown = parser.parse_known_args(argv)
    return args


def config_seed(seed):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU and CUDA generators with the same value.

    Args:
        seed: integer seed.
    """
    import random  # local import: the notebook's import cell does not load it

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [64]:
# ---- Experiment configuration ----
SEED = 10  # shared by the RNG seeding and the train/validation split
config_seed(SEED)

batch_size = 128
epochs = 50
lr = 0.001
weight_decay = 0.0001
model_type = "vgg_16"

# ---- Data ----
data = pd.read_csv("/kaggle/input/digit-recognizer/train.csv")

train_data, val_data = train_test_split(data, test_size=0.2, random_state=SEED)

# Named `digit_transform` rather than `transforms`: rebinding `transforms` would
# shadow the torchvision module and make this cell fail when it is run again.
digit_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

train_loader = get_loader(train_data, transforms=digit_transform, batch_size=batch_size)
# Validation metrics do not depend on sample order, so the validation loader is
# built without shuffling to keep evaluation deterministic.
val_loader = DataLoader(
    Digit_Dataset(val_data, digit_transform),
    batch_size=batch_size,
    shuffle=False,
)

dict_dataloader = {"train_loader": train_loader, "val_loader": val_loader}

# ---- Model, optimizer, loss ----
model = Model_VGG(model_type)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

# ---- Experiment tracking ----
wandb.init(
    project="Digit_Recognizer",
    group="VGG",
    name="{}_{}_{}_{}".format(model_type, batch_size, epochs, lr),
)

wandb.watch(model)

# ---- Training loop: one training pass, then one validation pass, per epoch ----
training_model = Training_Model(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    dict_dataloader=dict_dataloader,
)

for epoch in range(epochs):
    training_model.train(epoch)
    training_model.validation(epoch)

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy_score,▁▅█
train_f1_score,▁▅█
train_loss,█▄▁
train_precision_score,▁▅█
val_accuracy_score,▁▇█
val_f1_score,▁▇█
val_loss,█▃▁
val_precision_score,▁▇█

0,1
train_accuracy_score,0.9157
train_f1_score,0.9151
train_loss,0.3119
train_precision_score,0.9157
val_accuracy_score,0.969
val_f1_score,0.9689
val_loss,0.1432
val_precision_score,0.9695




100%|██████████| 263/263 [00:55<00:00,  4.77it/s]


train_loss: 2.1088, train_precision: 0.2039, train_accuracy: 0.2108, train_f1_score: 0.1985


100%|██████████| 66/66 [00:07<00:00,  9.22it/s]
  _warn_prf(average, modifier, msg_start, len(result))


val_loss: 1.6357, val_precision: 0.2878, val_accuracy: 0.2861, val_f1_score: 0.2318


100%|██████████| 263/263 [00:45<00:00,  5.73it/s]


train_loss: 0.8564, train_precision: 0.68, train_accuracy: 0.6735, train_f1_score: 0.666


100%|██████████| 66/66 [00:07<00:00,  9.00it/s]


val_loss: 0.5227, val_precision: 0.85, val_accuracy: 0.8527, val_f1_score: 0.8359


100%|██████████| 263/263 [00:45<00:00,  5.75it/s]


train_loss: 0.1595, train_precision: 0.9588, train_accuracy: 0.9591, train_f1_score: 0.9588


100%|██████████| 66/66 [00:07<00:00,  8.67it/s]


val_loss: 0.0947, val_precision: 0.9797, val_accuracy: 0.9794, val_f1_score: 0.9794


100%|██████████| 263/263 [00:46<00:00,  5.65it/s]


train_loss: 0.1192, train_precision: 0.9705, train_accuracy: 0.9706, train_f1_score: 0.9706


100%|██████████| 66/66 [00:07<00:00,  8.89it/s]


val_loss: 0.0902, val_precision: 0.9781, val_accuracy: 0.978, val_f1_score: 0.978


100%|██████████| 263/263 [00:46<00:00,  5.68it/s]


train_loss: 0.0713, train_precision: 0.9833, train_accuracy: 0.9834, train_f1_score: 0.9834


100%|██████████| 66/66 [00:07<00:00,  8.87it/s]


val_loss: 0.0611, val_precision: 0.9824, val_accuracy: 0.9821, val_f1_score: 0.9821


100%|██████████| 263/263 [00:46<00:00,  5.70it/s]


train_loss: 0.0728, train_precision: 0.9822, train_accuracy: 0.9823, train_f1_score: 0.9823


100%|██████████| 66/66 [00:07<00:00,  9.05it/s]


val_loss: 0.0501, val_precision: 0.9877, val_accuracy: 0.9876, val_f1_score: 0.9877


100%|██████████| 263/263 [00:45<00:00,  5.76it/s]


train_loss: 0.0698, train_precision: 0.9837, train_accuracy: 0.9837, train_f1_score: 0.9838


100%|██████████| 66/66 [00:07<00:00,  8.56it/s]


val_loss: 0.124, val_precision: 0.9765, val_accuracy: 0.9755, val_f1_score: 0.9756


100%|██████████| 263/263 [00:45<00:00,  5.75it/s]


train_loss: 0.0907, train_precision: 0.9794, train_accuracy: 0.9795, train_f1_score: 0.9794


100%|██████████| 66/66 [00:07<00:00,  9.07it/s]


val_loss: 0.0783, val_precision: 0.9869, val_accuracy: 0.987, val_f1_score: 0.987


100%|██████████| 263/263 [00:45<00:00,  5.78it/s]


train_loss: 0.067, train_precision: 0.985, train_accuracy: 0.9851, train_f1_score: 0.985


100%|██████████| 66/66 [00:07<00:00,  9.01it/s]


val_loss: 0.1, val_precision: 0.9798, val_accuracy: 0.979, val_f1_score: 0.9793


100%|██████████| 263/263 [00:46<00:00,  5.71it/s]


train_loss: 0.0729, train_precision: 0.984, train_accuracy: 0.9841, train_f1_score: 0.984


100%|██████████| 66/66 [00:07<00:00,  8.48it/s]


val_loss: 0.1063, val_precision: 0.982, val_accuracy: 0.9818, val_f1_score: 0.9818


100%|██████████| 263/263 [00:46<00:00,  5.65it/s]


train_loss: 0.1218, train_precision: 0.9733, train_accuracy: 0.9733, train_f1_score: 0.9731


100%|██████████| 66/66 [00:07<00:00,  9.11it/s]


val_loss: 0.0702, val_precision: 0.9833, val_accuracy: 0.9835, val_f1_score: 0.9834


100%|██████████| 263/263 [00:46<00:00,  5.62it/s]


train_loss: 0.0535, train_precision: 0.989, train_accuracy: 0.9891, train_f1_score: 0.989


100%|██████████| 66/66 [00:07<00:00,  9.28it/s]


val_loss: 0.0559, val_precision: 0.989, val_accuracy: 0.989, val_f1_score: 0.9889


100%|██████████| 263/263 [00:46<00:00,  5.70it/s]


train_loss: 0.0376, train_precision: 0.9919, train_accuracy: 0.9919, train_f1_score: 0.9919


100%|██████████| 66/66 [00:07<00:00,  9.00it/s]


val_loss: 0.0734, val_precision: 0.9874, val_accuracy: 0.9875, val_f1_score: 0.9873


100%|██████████| 263/263 [00:45<00:00,  5.76it/s]


train_loss: 0.089, train_precision: 0.9823, train_accuracy: 0.9825, train_f1_score: 0.9824


100%|██████████| 66/66 [00:07<00:00,  8.65it/s]


val_loss: 0.0858, val_precision: 0.9868, val_accuracy: 0.9869, val_f1_score: 0.9867


100%|██████████| 263/263 [00:45<00:00,  5.79it/s]


train_loss: 0.0371, train_precision: 0.993, train_accuracy: 0.993, train_f1_score: 0.993


100%|██████████| 66/66 [00:07<00:00,  9.19it/s]


val_loss: 0.0818, val_precision: 0.982, val_accuracy: 0.9812, val_f1_score: 0.9813


100%|██████████| 263/263 [00:46<00:00,  5.71it/s]


train_loss: 0.0404, train_precision: 0.9919, train_accuracy: 0.9919, train_f1_score: 0.9919


100%|██████████| 66/66 [00:07<00:00,  8.78it/s]


val_loss: 0.0581, val_precision: 0.9896, val_accuracy: 0.9895, val_f1_score: 0.9894


100%|██████████| 263/263 [00:46<00:00,  5.70it/s]


train_loss: 0.0344, train_precision: 0.9938, train_accuracy: 0.9939, train_f1_score: 0.9938


100%|██████████| 66/66 [00:07<00:00,  8.37it/s]


val_loss: 0.0878, val_precision: 0.9852, val_accuracy: 0.985, val_f1_score: 0.985


100%|██████████| 263/263 [00:45<00:00,  5.79it/s]


train_loss: 0.2391, train_precision: 0.9507, train_accuracy: 0.9504, train_f1_score: 0.9502


100%|██████████| 66/66 [00:07<00:00,  9.11it/s]


val_loss: 0.1289, val_precision: 0.9764, val_accuracy: 0.9763, val_f1_score: 0.9763


100%|██████████| 263/263 [00:45<00:00,  5.77it/s]


train_loss: 0.0827, train_precision: 0.983, train_accuracy: 0.983, train_f1_score: 0.9829


100%|██████████| 66/66 [00:07<00:00,  9.26it/s]


val_loss: 0.0684, val_precision: 0.9862, val_accuracy: 0.9863, val_f1_score: 0.9861


100%|██████████| 263/263 [00:46<00:00,  5.70it/s]


train_loss: 0.0345, train_precision: 0.9921, train_accuracy: 0.9922, train_f1_score: 0.9921


100%|██████████| 66/66 [00:07<00:00,  8.76it/s]


val_loss: 0.0757, val_precision: 0.988, val_accuracy: 0.9879, val_f1_score: 0.9879


100%|██████████| 263/263 [00:45<00:00,  5.80it/s]


train_loss: 0.0433, train_precision: 0.9914, train_accuracy: 0.9914, train_f1_score: 0.9914


100%|██████████| 66/66 [00:07<00:00,  9.17it/s]


val_loss: 0.0755, val_precision: 0.9863, val_accuracy: 0.9861, val_f1_score: 0.9861


100%|██████████| 263/263 [00:46<00:00,  5.71it/s]


train_loss: 0.0324, train_precision: 0.9927, train_accuracy: 0.9927, train_f1_score: 0.9927


100%|██████████| 66/66 [00:07<00:00,  9.20it/s]


val_loss: 0.0771, val_precision: 0.9869, val_accuracy: 0.9869, val_f1_score: 0.9869


100%|██████████| 263/263 [00:46<00:00,  5.67it/s]


train_loss: 0.0296, train_precision: 0.9939, train_accuracy: 0.9938, train_f1_score: 0.9939


100%|██████████| 66/66 [00:07<00:00,  8.89it/s]


val_loss: 0.0814, val_precision: 0.9883, val_accuracy: 0.9882, val_f1_score: 0.9881


100%|██████████| 263/263 [00:45<00:00,  5.81it/s]


train_loss: 0.0332, train_precision: 0.9936, train_accuracy: 0.9936, train_f1_score: 0.9936


100%|██████████| 66/66 [00:07<00:00,  9.07it/s]


val_loss: 0.0611, val_precision: 0.9893, val_accuracy: 0.9893, val_f1_score: 0.9893


100%|██████████| 263/263 [00:45<00:00,  5.76it/s]


train_loss: 0.0201, train_precision: 0.9961, train_accuracy: 0.9961, train_f1_score: 0.9961


100%|██████████| 66/66 [00:07<00:00,  9.32it/s]


val_loss: 0.1298, val_precision: 0.9859, val_accuracy: 0.9857, val_f1_score: 0.9856


100%|██████████| 263/263 [00:45<00:00,  5.74it/s]


train_loss: 0.0894, train_precision: 0.9805, train_accuracy: 0.9804, train_f1_score: 0.9805


100%|██████████| 66/66 [00:07<00:00,  9.27it/s]


val_loss: 0.1074, val_precision: 0.984, val_accuracy: 0.9837, val_f1_score: 0.9837


100%|██████████| 263/263 [00:45<00:00,  5.78it/s]


train_loss: 0.0468, train_precision: 0.9907, train_accuracy: 0.9908, train_f1_score: 0.9907


100%|██████████| 66/66 [00:07<00:00,  9.11it/s]


val_loss: 0.0831, val_precision: 0.9871, val_accuracy: 0.9868, val_f1_score: 0.9868


100%|██████████| 263/263 [00:44<00:00,  5.88it/s]


train_loss: 0.0264, train_precision: 0.9951, train_accuracy: 0.9951, train_f1_score: 0.9951


100%|██████████| 66/66 [00:07<00:00,  9.03it/s]


val_loss: 0.0849, val_precision: 0.9901, val_accuracy: 0.9901, val_f1_score: 0.99


100%|██████████| 263/263 [00:45<00:00,  5.83it/s]


train_loss: 0.015, train_precision: 0.9969, train_accuracy: 0.9969, train_f1_score: 0.9969


100%|██████████| 66/66 [00:07<00:00,  9.40it/s]


val_loss: 0.057, val_precision: 0.9903, val_accuracy: 0.9904, val_f1_score: 0.9902


100%|██████████| 263/263 [00:44<00:00,  5.86it/s]


train_loss: 0.0463, train_precision: 0.9912, train_accuracy: 0.9912, train_f1_score: 0.9912


100%|██████████| 66/66 [00:07<00:00,  8.80it/s]


val_loss: 0.1003, val_precision: 0.9901, val_accuracy: 0.9901, val_f1_score: 0.99


100%|██████████| 263/263 [00:44<00:00,  5.85it/s]


train_loss: 0.026, train_precision: 0.9949, train_accuracy: 0.9949, train_f1_score: 0.9949


100%|██████████| 66/66 [00:06<00:00,  9.46it/s]


val_loss: 0.0564, val_precision: 0.9897, val_accuracy: 0.9898, val_f1_score: 0.9896


100%|██████████| 263/263 [00:44<00:00,  5.89it/s]


train_loss: 0.0462, train_precision: 0.9926, train_accuracy: 0.9926, train_f1_score: 0.9926


100%|██████████| 66/66 [00:06<00:00,  9.43it/s]


val_loss: 0.1787, val_precision: 0.9712, val_accuracy: 0.9704, val_f1_score: 0.9702


100%|██████████| 263/263 [00:44<00:00,  5.88it/s]


train_loss: 0.0435, train_precision: 0.9914, train_accuracy: 0.9915, train_f1_score: 0.9914


100%|██████████| 66/66 [00:06<00:00,  9.43it/s]


val_loss: 0.0837, val_precision: 0.9894, val_accuracy: 0.9895, val_f1_score: 0.9894


100%|██████████| 263/263 [00:44<00:00,  5.85it/s]


train_loss: 0.0225, train_precision: 0.9954, train_accuracy: 0.9954, train_f1_score: 0.9954


100%|██████████| 66/66 [00:07<00:00,  9.25it/s]


val_loss: 0.0738, val_precision: 0.9897, val_accuracy: 0.9896, val_f1_score: 0.9896


100%|██████████| 263/263 [00:45<00:00,  5.76it/s]


train_loss: 0.026, train_precision: 0.9953, train_accuracy: 0.9953, train_f1_score: 0.9953


100%|██████████| 66/66 [00:07<00:00,  9.24it/s]


val_loss: 0.0663, val_precision: 0.9915, val_accuracy: 0.9915, val_f1_score: 0.9915


100%|██████████| 263/263 [00:44<00:00,  5.85it/s]


train_loss: 0.0178, train_precision: 0.9963, train_accuracy: 0.9963, train_f1_score: 0.9963


100%|██████████| 66/66 [00:06<00:00,  9.54it/s]


val_loss: 0.0903, val_precision: 0.9877, val_accuracy: 0.9877, val_f1_score: 0.9875


100%|██████████| 263/263 [00:44<00:00,  5.86it/s]


train_loss: 0.104, train_precision: 0.9793, train_accuracy: 0.9794, train_f1_score: 0.9793


100%|██████████| 66/66 [00:07<00:00,  9.36it/s]


val_loss: 0.091, val_precision: 0.9854, val_accuracy: 0.9856, val_f1_score: 0.9854


100%|██████████| 263/263 [00:44<00:00,  5.86it/s]


train_loss: 0.0339, train_precision: 0.9937, train_accuracy: 0.9937, train_f1_score: 0.9937


100%|██████████| 66/66 [00:07<00:00,  9.05it/s]


val_loss: 0.0747, val_precision: 0.9896, val_accuracy: 0.9896, val_f1_score: 0.9895


100%|██████████| 263/263 [00:45<00:00,  5.82it/s]


train_loss: 0.012, train_precision: 0.9977, train_accuracy: 0.9977, train_f1_score: 0.9977


100%|██████████| 66/66 [00:06<00:00,  9.53it/s]


val_loss: 0.0918, val_precision: 0.9913, val_accuracy: 0.9914, val_f1_score: 0.9913


100%|██████████| 263/263 [00:45<00:00,  5.82it/s]


train_loss: 0.0122, train_precision: 0.9973, train_accuracy: 0.9974, train_f1_score: 0.9973


100%|██████████| 66/66 [00:07<00:00,  9.43it/s]


val_loss: 0.1015, val_precision: 0.9905, val_accuracy: 0.9907, val_f1_score: 0.9906


100%|██████████| 263/263 [00:44<00:00,  5.86it/s]


train_loss: 0.038, train_precision: 0.9929, train_accuracy: 0.993, train_f1_score: 0.993


100%|██████████| 66/66 [00:07<00:00,  9.41it/s]


val_loss: 0.0763, val_precision: 0.9843, val_accuracy: 0.9836, val_f1_score: 0.9837


100%|██████████| 263/263 [00:45<00:00,  5.83it/s]


train_loss: 0.0368, train_precision: 0.9925, train_accuracy: 0.9926, train_f1_score: 0.9925


100%|██████████| 66/66 [00:07<00:00,  9.31it/s]


val_loss: 0.0499, val_precision: 0.9912, val_accuracy: 0.9912, val_f1_score: 0.9911


100%|██████████| 263/263 [00:44<00:00,  5.86it/s]


train_loss: 0.0174, train_precision: 0.9965, train_accuracy: 0.9965, train_f1_score: 0.9965


100%|██████████| 66/66 [00:06<00:00,  9.44it/s]


val_loss: 0.0612, val_precision: 0.9908, val_accuracy: 0.9907, val_f1_score: 0.9907


100%|██████████| 263/263 [00:44<00:00,  5.88it/s]


train_loss: 0.0226, train_precision: 0.9965, train_accuracy: 0.9965, train_f1_score: 0.9965


100%|██████████| 66/66 [00:07<00:00,  8.93it/s]


val_loss: 0.0563, val_precision: 0.9919, val_accuracy: 0.992, val_f1_score: 0.9919


100%|██████████| 263/263 [00:44<00:00,  5.87it/s]


train_loss: 0.0124, train_precision: 0.9974, train_accuracy: 0.9974, train_f1_score: 0.9974


100%|██████████| 66/66 [00:07<00:00,  9.13it/s]


val_loss: 0.063, val_precision: 0.9915, val_accuracy: 0.9914, val_f1_score: 0.9914


100%|██████████| 263/263 [00:45<00:00,  5.78it/s]


train_loss: 0.0177, train_precision: 0.997, train_accuracy: 0.997, train_f1_score: 0.997


100%|██████████| 66/66 [00:07<00:00,  9.40it/s]


val_loss: 0.0616, val_precision: 0.99, val_accuracy: 0.99, val_f1_score: 0.99


100%|██████████| 263/263 [00:45<00:00,  5.79it/s]


train_loss: 0.0304, train_precision: 0.9955, train_accuracy: 0.9955, train_f1_score: 0.9955


100%|██████████| 66/66 [00:07<00:00,  9.25it/s]


val_loss: 0.0764, val_precision: 0.99, val_accuracy: 0.9902, val_f1_score: 0.9901


100%|██████████| 263/263 [00:45<00:00,  5.84it/s]


train_loss: 0.0138, train_precision: 0.9974, train_accuracy: 0.9974, train_f1_score: 0.9974


100%|██████████| 66/66 [00:07<00:00,  9.17it/s]


val_loss: 0.0712, val_precision: 0.9903, val_accuracy: 0.9904, val_f1_score: 0.9903


100%|██████████| 263/263 [00:45<00:00,  5.82it/s]


train_loss: 0.0117, train_precision: 0.998, train_accuracy: 0.998, train_f1_score: 0.998


100%|██████████| 66/66 [00:07<00:00,  9.03it/s]


val_loss: 0.1061, val_precision: 0.9898, val_accuracy: 0.9899, val_f1_score: 0.9897


100%|██████████| 263/263 [00:45<00:00,  5.75it/s]


train_loss: 0.0232, train_precision: 0.996, train_accuracy: 0.996, train_f1_score: 0.996


100%|██████████| 66/66 [00:07<00:00,  9.00it/s]


val_loss: 0.172, val_precision: 0.9772, val_accuracy: 0.9755, val_f1_score: 0.9756


In [None]:
# Restore trained weights. `map_location` lets a checkpoint saved on GPU load in
# a CPU-only session as well.
# NOTE(review): the training log above reaches val_f1 0.9919 and the submission
# file written below is tagged 0.9915, yet this loads the 0.9902 checkpoint.
# Confirm which checkpoint is intended.
model.load_state_dict(
    torch.load("/kaggle/working/vgg_16_f1_0.9902.pt", map_location=device)
)

In [None]:
# Make sure the restored model lives on the device used for inference.
model = model.to(device)

In [None]:
# Load the competition test set; the inference cell reshapes each full row into
# a 28x28 image, so every column is treated as a pixel (no label column).
DF_test = pd.read_csv("/kaggle/input/digit-recognizer/test.csv")

In [None]:
# Re-import so `transforms` is the torchvision module even if an earlier cell
# rebound the name to a Compose pipeline.
from torchvision import transforms

INFER_BATCH_SIZE = 256  # images per forward pass

# Same preprocessing as used for training; kept under its own name so the
# `transforms` module is not shadowed and this cell stays re-runnable.
test_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Every test row is one flattened 28x28 image. DF_test itself is left untouched.
test_images = DF_test.to_numpy().reshape(len(DF_test), 28, 28).astype(np.uint8)

# Inference mode: disables dropout in the VGG classifier head.
model.eval()

predictions = []
with torch.no_grad():
    for start in tqdm(range(0, len(test_images), INFER_BATCH_SIZE)):
        chunk = test_images[start : start + INFER_BATCH_SIZE]
        # Replicate the grayscale image to 3 channels, preprocess, then batch.
        batch = torch.stack(
            [test_transform(np.stack((img,) * 3, axis=-1)) for img in chunk]
        ).to(device)
        predictions.extend(model(batch).argmax(dim=1).tolist())

# Build the submission frame in one go (ImageId is 1-based) instead of growing
# it row by row with DataFrame.append, which is quadratic and gone in pandas 2.
DF_submission = pd.DataFrame(
    {"ImageId": np.arange(1, len(predictions) + 1), "Label": predictions}
)
DF_submission.to_csv("/kaggle/working/vgg_16_f1_50_0.9915_submission.csv", index=False)
