In [1]:
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torchvision import models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import random

In [None]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
    random.seed(seed)
    np.random.seed(seed)

In [None]:
seed = 114514
set_seed(seed = seed)

In [1]:
!nvidia-smi

Fri Aug  5 11:41:51 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 512.78       Driver Version: 512.78       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
| N/A   58C    P8     6W /  N/A |    603MiB /  6144MiB |     36%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
class Paddy(Dataset):
    def __init__(self, files_name, transform, mode = "train"):
        super(Paddy).__init__()
        self.tfm = transform
        self.mode = mode
        self.paddy_dict = {'bacterial_leaf_blight': 0, 'bacterial_leaf_streak': 1, 'bacterial_panicle_blight': 2, \
            'blast': 3, 'brown_spot': 4, 'dead_heart': 5, 'downy_mildew': 6, 'hispa': 7, 'normal': 8, 'tungro': 9}
        self.files_name = files_name

    def __len__(self):
        return len(self.files_name)
    
    def __getitem__(self, idx):
        file_name = self.files_name[idx]
        im = self.tfm(Image.open(file_name))
        if self.mode == "train":
            label = self.paddy_dict[file_name.split("/")[-2]]
        else:
            label = -1
        return im, label

In [5]:
train_transform = transforms.Compose([
    transforms.RandomResizedCrop((224, 224), scale=(0.8, 1.0)),
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(180),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [6]:
### Configuration
batch_size = 64
n_epochs = 120
learning_rate = 0.001
K_fold = 5
early_stopping_count = 24
best_model_save_path = "./models"

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [7]:
### Train_path is "paddy-disease-classification"; Test_path is "test_images"
Train_path = './'
Test_path = './test_images'

train_files_name = []
train_info = pd.read_csv(f'{Train_path}train.csv')
for index, row in train_info.iterrows():
    train_files_name.append(f'{Train_path}train_images/{row["label"]}/{row["image_id"]}')
random.shuffle(train_files_name)
    
test_files_name = []
for image in os.listdir(Test_path):
    test_files_name.append(f'{Test_path}/{image}')

In [8]:
Train_set = []
Val_set = []
set_num = len(train_files_name) // K_fold
for i in range(K_fold):
    Train_set.append(train_files_name[: i * set_num] + train_files_name[(i + 1) * set_num :])
    Val_set.append(train_files_name[i * set_num : (i + 1) * set_num])

In [20]:
def train():
    for i in range(K_fold):
        print(f"###################### Processing fold: {i} ###################### \n")
        model = models.resnet34(pretrained = True)
        model.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(model.fc.in_features, 10)
        )
        model = model.to(device)
        
        optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, weight_decay=0.00001)
        #scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0 = 12, T_mult = 1)
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr = 1e-5, max_lr = 0.001, step_size_up = 4, mode="triangular2")
        criterion = nn.CrossEntropyLoss()

        train_data = Paddy(files_name = Train_set[i], transform = train_transform, mode = "train")
        val_data = Paddy(files_name = Val_set[i], transform = test_transform, mode = "train")

        train_loader = DataLoader(train_data, batch_size = batch_size, shuffle = True, num_workers = 4, pin_memory = True)
        val_loader = DataLoader(val_data, batch_size = batch_size, shuffle = True, num_workers = 4, pin_memory = True)
        
        with open('./record.txt', 'a') as f:
            f.write(f'Fold: {i + 1}\n')

        Acc = 0.
        stopping_count = 0
        for epoch in range(n_epochs):
            train_loss = []
            train_acc = []
            val_loss = []
            val_acc = []

            model.train()
            pbar = tqdm(train_loader, ncols = 120)

            for images, labels in pbar:
                optimizer.zero_grad()

                images, labels = images.to(device), labels.to(device)

                prediction = model(images)
                loss = criterion(prediction, labels)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm = 20)
                optimizer.step()

                train_loss.append(loss.detach().item())

                ### [batch_size * 10]
                train_acc.append((prediction.argmax(dim = 1).flatten() == labels.flatten()).float().mean().item())

                pbar.set_description(f"Processing train epoch [{epoch + 1} / {n_epochs}]")
                pbar.set_postfix({'Train loss': '{0:1.5f}'.format(sum(train_loss) / len(train_loss)), \
                                  'Train accuracy': '{0:1.5f}'.format(sum(train_acc) / len(train_acc))})    

            scheduler.step()


            ### Validation Part
            model.eval()
            pbar = tqdm(val_loader, ncols = 120)

            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)

                with torch.no_grad():
                    prediction = model(images)
                    loss = criterion(prediction, labels)

                val_loss.append(loss.detach().item())

                ### [batch_size * 10]
                val_acc.append((prediction.argmax(dim = 1).flatten() == labels.flatten()).float().mean().item())

                pbar.set_description(f"Processing validation epoch [{epoch + 1} / {n_epochs}]")
                pbar.set_postfix({'Val loss': '{0:1.5f}'.format(sum(val_loss) / len(val_loss)), \
                                  'Val accuracy': '{0:1.5f}'.format(sum(val_acc) / len(val_acc))})

            Val_acc = sum(val_acc) / len(val_acc)

            with open('./record.txt', 'a') as f:
                f.write(f"Processing train epoch [{epoch + 1} / {n_epochs}]\n")
                f.write(f'Train loss: {sum(train_loss) / len(train_loss)}\n')
                f.write(f'Train acc: {sum(train_acc) / len(train_acc)}\n')
                f.write(f'Val loss: {sum(val_loss) / len(val_loss)}\n')
                f.write(f'Val acc: {sum(val_acc) / len(val_acc)}\n')

            if Val_acc > Acc:
                Acc = Val_acc
                print(f'Find the best model in Fold {i + 1}')
                model_path = best_model_save_path + 'Fold_' + str(i + 1) + '.ckpt'
                torch.save(model.state_dict(), model_path)
                stopping_count = 0
            else:
                stopping_count += 1
                if stopping_count >= early_stopping_count:
                    print("###################### Early stopped ###################### \n")
                    break

In [None]:
train()

###################### Processing fold: 0 ###################### 



Processing train epoch [1 / 120]: 100%|███| 131/131 [00:51<00:00,  2.55it/s, Train loss=1.19140, Train accuracy=0.59681]
Processing validation epoch [1 / 120]: 100%|████| 33/33 [00:12<00:00,  2.61it/s, Val loss=1.38126, Val accuracy=0.58939]


Find the best model in Fold 1


Processing train epoch [2 / 120]: 100%|███| 131/131 [00:25<00:00,  5.11it/s, Train loss=0.74685, Train accuracy=0.74571]
Processing validation epoch [2 / 120]: 100%|████| 33/33 [00:06<00:00,  5.10it/s, Val loss=0.79601, Val accuracy=0.74277]


Find the best model in Fold 1


Processing train epoch [3 / 120]: 100%|███| 131/131 [00:25<00:00,  5.17it/s, Train loss=0.59342, Train accuracy=0.80129]
Processing validation epoch [3 / 120]: 100%|████| 33/33 [00:06<00:00,  5.17it/s, Val loss=1.46726, Val accuracy=0.61078]
Processing train epoch [4 / 120]: 100%|███| 131/131 [00:25<00:00,  5.15it/s, Train loss=0.41817, Train accuracy=0.85790]
Processing validation epoch [4 / 120]: 100%|████| 33/33 [00:06<00:00,  5.48it/s, Val loss=0.57628, Val accuracy=0.80741]


Find the best model in Fold 1


Processing train epoch [5 / 120]: 100%|███| 131/131 [00:26<00:00,  5.02it/s, Train loss=0.37980, Train accuracy=0.87647]
Processing validation epoch [5 / 120]: 100%|████| 33/33 [00:06<00:00,  5.43it/s, Val loss=0.41542, Val accuracy=0.85810]


Find the best model in Fold 1


Processing train epoch [6 / 120]: 100%|███| 131/131 [00:25<00:00,  5.11it/s, Train loss=0.29498, Train accuracy=0.90498]
Processing validation epoch [6 / 120]: 100%|████| 33/33 [00:06<00:00,  5.40it/s, Val loss=0.36612, Val accuracy=0.88740]


Find the best model in Fold 1


Processing train epoch [7 / 120]: 100%|███| 131/131 [00:26<00:00,  5.03it/s, Train loss=0.21118, Train accuracy=0.93146]
Processing validation epoch [7 / 120]: 100%|████| 33/33 [00:06<00:00,  5.19it/s, Val loss=0.36189, Val accuracy=0.88592]
Processing train epoch [8 / 120]: 100%|███| 131/131 [00:25<00:00,  5.13it/s, Train loss=0.16267, Train accuracy=0.94633]
Processing validation epoch [8 / 120]: 100%|████| 33/33 [00:06<00:00,  5.41it/s, Val loss=0.19839, Val accuracy=0.94558]


Find the best model in Fold 1


Processing train epoch [9 / 120]: 100%|███| 131/131 [00:25<00:00,  5.13it/s, Train loss=0.12734, Train accuracy=0.96100]
Processing validation epoch [9 / 120]: 100%|████| 33/33 [00:06<00:00,  5.23it/s, Val loss=0.16610, Val accuracy=0.94655]


Find the best model in Fold 1


Processing train epoch [10 / 120]: 100%|██| 131/131 [00:25<00:00,  5.19it/s, Train loss=0.09086, Train accuracy=0.97141]
Processing validation epoch [10 / 120]: 100%|███| 33/33 [00:06<00:00,  5.09it/s, Val loss=0.13429, Val accuracy=0.96404]


Find the best model in Fold 1


Processing train epoch [11 / 120]: 100%|██| 131/131 [00:25<00:00,  5.15it/s, Train loss=0.07045, Train accuracy=0.97913]
Processing validation epoch [11 / 120]: 100%|███| 33/33 [00:06<00:00,  5.28it/s, Val loss=0.12685, Val accuracy=0.96212]
Processing train epoch [12 / 120]: 100%|██| 131/131 [00:25<00:00,  5.13it/s, Train loss=0.06973, Train accuracy=0.97774]
Processing validation epoch [12 / 120]: 100%|███| 33/33 [00:06<00:00,  5.29it/s, Val loss=0.12711, Val accuracy=0.96404]
Processing train epoch [13 / 120]: 100%|██| 131/131 [00:25<00:00,  5.21it/s, Train loss=0.46435, Train accuracy=0.85059]
Processing validation epoch [13 / 120]: 100%|███| 33/33 [00:06<00:00,  5.30it/s, Val loss=0.67486, Val accuracy=0.79515]
Processing train epoch [14 / 120]: 100%|██| 131/131 [00:25<00:00,  5.07it/s, Train loss=0.35697, Train accuracy=0.88251]
Processing validation epoch [14 / 120]: 100%|███| 33/33 [00:06<00:00,  5.50it/s, Val loss=0.41848, Val accuracy=0.86088]
Processing train epoch [15 / 120

Find the best model in Fold 1


Processing train epoch [22 / 120]: 100%|██| 131/131 [00:25<00:00,  5.18it/s, Train loss=0.05847, Train accuracy=0.98032]
Processing validation epoch [22 / 120]: 100%|███| 33/33 [00:06<00:00,  5.08it/s, Val loss=0.09771, Val accuracy=0.97351]


Find the best model in Fold 1


Processing train epoch [23 / 120]: 100%|██| 131/131 [00:25<00:00,  5.06it/s, Train loss=0.04503, Train accuracy=0.98581]
Processing validation epoch [23 / 120]: 100%|███| 33/33 [00:06<00:00,  5.36it/s, Val loss=0.09653, Val accuracy=0.97446]


Find the best model in Fold 1


Processing train epoch [24 / 120]: 100%|██| 131/131 [00:25<00:00,  5.11it/s, Train loss=0.04161, Train accuracy=0.98700]
Processing validation epoch [24 / 120]: 100%|███| 33/33 [00:06<00:00,  5.40it/s, Val loss=0.09740, Val accuracy=0.97446]
Processing train epoch [25 / 120]: 100%|██| 131/131 [00:25<00:00,  5.08it/s, Train loss=0.28410, Train accuracy=0.90856]
Processing validation epoch [25 / 120]: 100%|███| 33/33 [00:05<00:00,  5.53it/s, Val loss=0.44108, Val accuracy=0.86331]
Processing train epoch [26 / 120]: 100%|██| 131/131 [00:25<00:00,  5.15it/s, Train loss=0.34279, Train accuracy=0.88844]
Processing validation epoch [26 / 120]: 100%|███| 33/33 [00:06<00:00,  5.30it/s, Val loss=0.36457, Val accuracy=0.89639]
Processing train epoch [27 / 120]: 100%|██| 131/131 [00:25<00:00,  5.06it/s, Train loss=0.21076, Train accuracy=0.92915]
Processing validation epoch [27 / 120]: 100%|███| 33/33 [00:06<00:00,  5.04it/s, Val loss=0.27452, Val accuracy=0.91383]
Processing train epoch [28 / 120

Find the best model in Fold 1


Processing train epoch [46 / 120]: 100%|██| 131/131 [00:25<00:00,  5.04it/s, Train loss=0.04634, Train accuracy=0.98441]
Processing validation epoch [46 / 120]: 100%|███| 33/33 [00:06<00:00,  5.34it/s, Val loss=0.10524, Val accuracy=0.97499]
Processing train epoch [47 / 120]: 100%|██| 131/131 [00:25<00:00,  5.10it/s, Train loss=0.03563, Train accuracy=0.98895]
Processing validation epoch [47 / 120]: 100%|███| 33/33 [00:06<00:00,  5.25it/s, Val loss=0.08995, Val accuracy=0.97777]


Find the best model in Fold 1


Processing train epoch [48 / 120]: 100%|██| 131/131 [00:25<00:00,  5.07it/s, Train loss=0.02986, Train accuracy=0.98915]
Processing validation epoch [48 / 120]: 100%|███| 33/33 [00:06<00:00,  5.17it/s, Val loss=0.09268, Val accuracy=0.97635]
Processing train epoch [49 / 120]: 100%|██| 131/131 [00:26<00:00,  5.00it/s, Train loss=0.12589, Train accuracy=0.96537]
Processing validation epoch [49 / 120]: 100%|███| 33/33 [00:06<00:00,  5.29it/s, Val loss=0.35144, Val accuracy=0.90492]
Processing train epoch [50 / 120]: 100%|██| 131/131 [00:25<00:00,  5.14it/s, Train loss=0.25794, Train accuracy=0.91794]
Processing validation epoch [50 / 120]: 100%|███| 33/33 [00:06<00:00,  5.09it/s, Val loss=0.42425, Val accuracy=0.88169]
Processing train epoch [51 / 120]: 100%|██| 131/131 [00:25<00:00,  5.18it/s, Train loss=0.19081, Train accuracy=0.93802]
Processing validation epoch [51 / 120]: 100%|███| 33/33 [00:06<00:00,  5.45it/s, Val loss=0.23354, Val accuracy=0.92578]
Processing train epoch [52 / 120

###################### Early stopped ###################### 

###################### Processing fold: 1 ###################### 



Processing train epoch [1 / 120]: 100%|███| 131/131 [00:25<00:00,  5.14it/s, Train loss=1.15697, Train accuracy=0.61470]
Processing validation epoch [1 / 120]: 100%|████| 33/33 [00:06<00:00,  5.31it/s, Val loss=0.92645, Val accuracy=0.68213]


Find the best model in Fold 2


Processing train epoch [2 / 120]: 100%|███| 131/131 [00:25<00:00,  5.05it/s, Train loss=0.74002, Train accuracy=0.75573]
Processing validation epoch [2 / 120]: 100%|████| 33/33 [00:06<00:00,  5.37it/s, Val loss=0.68681, Val accuracy=0.77248]


Find the best model in Fold 2


Processing train epoch [3 / 120]: 100%|███| 131/131 [00:25<00:00,  5.11it/s, Train loss=0.56672, Train accuracy=0.81314]
Processing validation epoch [3 / 120]: 100%|████| 33/33 [00:06<00:00,  5.34it/s, Val loss=0.50191, Val accuracy=0.83871]


Find the best model in Fold 2


Processing train epoch [4 / 120]: 100%|███| 131/131 [00:26<00:00,  5.01it/s, Train loss=0.45770, Train accuracy=0.85095]
Processing validation epoch [4 / 120]: 100%|████| 33/33 [00:06<00:00,  5.21it/s, Val loss=0.54739, Val accuracy=0.81404]
Processing train epoch [5 / 120]: 100%|███| 131/131 [00:26<00:00,  4.94it/s, Train loss=0.32372, Train accuracy=0.89353]
Processing validation epoch [5 / 120]: 100%|████| 33/33 [00:06<00:00,  5.48it/s, Val loss=0.45838, Val accuracy=0.85804]


Find the best model in Fold 2


Processing train epoch [6 / 120]: 100%|███| 131/131 [00:25<00:00,  5.06it/s, Train loss=0.27732, Train accuracy=0.90824]
Processing validation epoch [6 / 120]: 100%|████| 33/33 [00:06<00:00,  5.19it/s, Val loss=0.49124, Val accuracy=0.85431]
Processing train epoch [7 / 120]: 100%|███| 131/131 [00:25<00:00,  5.05it/s, Train loss=0.23846, Train accuracy=0.92048]
Processing validation epoch [7 / 120]: 100%|████| 33/33 [00:06<00:00,  5.43it/s, Val loss=0.23384, Val accuracy=0.92477]


Find the best model in Fold 2


Processing train epoch [8 / 120]: 100%|███| 131/131 [00:25<00:00,  5.05it/s, Train loss=0.15800, Train accuracy=0.94843]
Processing validation epoch [8 / 120]: 100%|████| 33/33 [00:06<00:00,  5.33it/s, Val loss=0.19322, Val accuracy=0.94271]


Find the best model in Fold 2


Processing train epoch [9 / 120]: 100%|███| 131/131 [00:25<00:00,  5.10it/s, Train loss=0.12921, Train accuracy=0.96052]
Processing validation epoch [9 / 120]: 100%|████| 33/33 [00:06<00:00,  4.89it/s, Val loss=0.16731, Val accuracy=0.94753]


Find the best model in Fold 2


Processing train epoch [10 / 120]: 100%|██| 131/131 [00:26<00:00,  5.01it/s, Train loss=0.09898, Train accuracy=0.97026]
Processing validation epoch [10 / 120]: 100%|███| 33/33 [00:06<00:00,  5.04it/s, Val loss=0.13273, Val accuracy=0.95410]


Find the best model in Fold 2


Processing train epoch [11 / 120]: 100%|██| 131/131 [00:25<00:00,  5.06it/s, Train loss=0.07720, Train accuracy=0.97400]
Processing validation epoch [11 / 120]: 100%|███| 33/33 [00:06<00:00,  5.20it/s, Val loss=0.13302, Val accuracy=0.96070]


Find the best model in Fold 2


Processing train epoch [12 / 120]: 100%|██| 131/131 [00:25<00:00,  5.10it/s, Train loss=0.06442, Train accuracy=0.97996]
Processing validation epoch [12 / 120]: 100%|███| 33/33 [00:06<00:00,  5.35it/s, Val loss=0.12525, Val accuracy=0.96026]
Processing train epoch [13 / 120]: 100%|██| 131/131 [00:25<00:00,  5.11it/s, Train loss=0.40542, Train accuracy=0.86991]
Processing validation epoch [13 / 120]: 100%|███| 33/33 [00:06<00:00,  5.13it/s, Val loss=1.07551, Val accuracy=0.69912]
Processing train epoch [14 / 120]: 100%|██| 131/131 [00:25<00:00,  5.11it/s, Train loss=0.34583, Train accuracy=0.88546]
Processing validation epoch [14 / 120]: 100%|███| 33/33 [00:06<00:00,  5.11it/s, Val loss=0.38163, Val accuracy=0.87695]
Processing train epoch [15 / 120]: 100%|██| 131/131 [00:25<00:00,  5.12it/s, Train loss=0.30092, Train accuracy=0.90613]
Processing validation epoch [15 / 120]: 100%|███| 33/33 [00:06<00:00,  5.40it/s, Val loss=0.40695, Val accuracy=0.87843]
Processing train epoch [16 / 120

Find the best model in Fold 2


Processing train epoch [22 / 120]: 100%|██| 131/131 [00:26<00:00,  4.98it/s, Train loss=0.06014, Train accuracy=0.98008]
Processing validation epoch [22 / 120]: 100%|███| 33/33 [00:06<00:00,  5.06it/s, Val loss=0.10423, Val accuracy=0.97117]


Find the best model in Fold 2


Processing train epoch [23 / 120]: 100%|██| 131/131 [00:25<00:00,  5.11it/s, Train loss=0.04881, Train accuracy=0.98549]
Processing validation epoch [23 / 120]: 100%|███| 33/33 [00:06<00:00,  5.02it/s, Val loss=0.10373, Val accuracy=0.97159]


Find the best model in Fold 2


Processing train epoch [24 / 120]: 100%|██| 131/131 [00:25<00:00,  5.17it/s, Train loss=0.04174, Train accuracy=0.98628]
Processing validation epoch [24 / 120]: 100%|███| 33/33 [00:06<00:00,  5.09it/s, Val loss=0.10429, Val accuracy=0.96975]
Processing train epoch [25 / 120]: 100%|██| 131/131 [00:25<00:00,  5.07it/s, Train loss=0.27821, Train accuracy=0.91019]
Processing validation epoch [25 / 120]: 100%|███| 33/33 [00:06<00:00,  5.34it/s, Val loss=0.47203, Val accuracy=0.86085]
Processing train epoch [26 / 120]: 100%|██| 131/131 [00:25<00:00,  5.15it/s, Train loss=0.27219, Train accuracy=0.91078]
Processing validation epoch [26 / 120]: 100%|███| 33/33 [00:06<00:00,  5.27it/s, Val loss=0.26451, Val accuracy=0.91625]
Processing train epoch [27 / 120]: 100%|██| 131/131 [00:25<00:00,  5.06it/s, Train loss=0.16503, Train accuracy=0.94541]
Processing validation epoch [27 / 120]: 100%|███| 33/33 [00:06<00:00,  5.20it/s, Val loss=0.25467, Val accuracy=0.91583]
Processing train epoch [28 / 120

In [None]:
def predict(loader, k_fold):
    predictions = []
    for i in range(k_fold):
        tmp_prediction = []

        model_path = best_model_save_path + 'Fold_' + str(i + 1) + '.ckpt'

        model = models.resnet34(pretrained = False)
        model.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(model.fc.in_features, 10)
        )
        model.load_state_dict(torch.load(model_path))
        model = model.to(device)
        model.eval()

        pbar = tqdm(loader, ncols = 120)
        for image, _ in pbar:
            image = image.to(device)
            with torch.no_grad():
                prediction = model(image)
                tmp_prediction.append(prediction.detach().cpu())
            pbar.set_description(f"###################### Processing test fold {i + 1} ######################")

        predictions.append(torch.cat(tmp_prediction, dim = 0))

    return sum(predictions)

In [None]:
def test():
    augmented_loaders = []
    test_data = Paddy(files_name = test_files_name, transform = test_transform, mode = "test")
    test_loader = DataLoader(test_data, batch_size = batch_size, shuffle = False, num_workers = 2, pin_memory = True)
    ### Test time augmentation
    augmented_loaders.append(test_loader)
    for _ in range(5):
        augmented_data = Paddy(files_name = test_files_name, transform = train_transform, mode = "test")
        augmented_loaders.append(DataLoader(augmented_data, batch_size = batch_size, shuffle = False, num_workers = 2, pin_memory = True))
        
    predictions = []
    for _, loader in enumerate(augmented_loaders):
        prediction = predict(loader, K_fold)
        predictions.append(prediction)

    augmented_prediction = 0.5 * predictions[0]
    for i in range(1, len(predictions)):
        augmented_prediction += 0.1 * predictions[i]
    
    augmented_pred = augmented_prediction.argmax(dim = 1)

    print(augmented_pred)
    print(np.unique(augmented_pred))

    return augmented_pred

In [None]:
paddy_dict = {0: 'bacterial_leaf_blight', 
              1: 'bacterial_leaf_streak', 
              2: 'bacterial_panicle_blight',
              3: 'blast', 
              4: 'brown_spot', 
              5: 'dead_heart', 
              6: 'downy_mildew', 
              7: 'hispa', 
              8: 'normal', 
              9: 'tungro'
              }

augmented_pred = test()
print(f'Shape of the prediction is: {augmented_pred.shape}')

In [None]:
test_path = os.listdir(Test_path)
with open('./submission.csv', 'w') as f:
    f.write('image_id,label\n')
    for id, label in enumerate(augmented_pred):
        f.write(f'{test_path[id]},{paddy_dict[label.item()]}\n')