In [20]:
import os
import torch
import random
from itertools import permutations
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
from torchvision import datasets

In [2]:
filepath = "/kaggle/input/jigsaw/New folder"
permutationfilepath = "permutations500.pth"
modelsave = "/kaggle/working/jigsaw_encoder_3layer.pth"

if not os.path.exists(permutationfilepath):
    allpermutations = list(permutations(range(9)))
    random.shuffle(allpermutations)
    selectedpermutation = torch.tensor(allpermutations[:500])
    torch.save(selectedpermutation, permutationfilepath)
    print("Generated permutations500.pth")

Generated permutations500.pth


In [3]:
patchtransform = transforms.Compose([
    transforms.ToTensor(),
])


In [4]:
class Pretextdataset(Dataset):
    def __init__(self, rootdirectory, permutationfile, transformpatch):
        self.image_paths = [os.path.join(rootdirectory, f) for f in os.listdir(rootdirectory) if f.endswith(".jpg")]
        self.permutations = torch.load(permutationfile)
        self.transform_patch = transformpatch

    def croppatches(self, img):
        img = img.resize((255, 255))
        patches = []
        width, height = img.size
        newwidth, newheight = width // 3, height // 3
        for i in range(3):
            for j in range(3):
                patch = img.crop((j * newwidth, i * newheight, (j + 1) * newwidth, (i + 1) * newheight))
                patches.append(patch)
        return patches

    def __getitem__(self, index):
        img = Image.open(self.image_paths[index]).convert("RGB")
        patches = self.croppatches(img)
        permutation_index = random.randint(0, len(self.permutations) - 1)
        permutation = self.permutations[permutation_index]
        shuffled = [self.transform_patch(patches[i]) for i in permutation]
        return torch.stack(shuffled), permutation_index

    def __len__(self):
        return len(self.image_paths)

In [7]:
class Pretext(nn.Module):
    def __init__(self, outputdimension=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  

            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  

            nn.Conv2d(128, outputdimension, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(4)   
        )

    def forward(self, x):
        return self.encoder(x)

In [11]:
class Fullyconnectedpretext(nn.Module):
    def __init__(self, numpermutations=500):
        super().__init__()
        self.encoder = Pretext()
        self.classifier = nn.Sequential(
            nn.Flatten(),                            
            nn.Linear(28800, 1024),         
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, numpermutations)
        )

    def forward(self, x):
        B, N, C, H, W = x.shape                    
        x = x.view(B * N, C, H, W)                 
        features = self.encoder(x)                   
        features = features.view(B, 28800)                 
        return self.classifier(features)


In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("GPU Available:", torch.cuda.is_available())
e=10
batchsize=128
learningrate=1e-3
N = 500
model = Fullyconnectedpretext(numpermutations=N).to(device)

GPU Available: True


In [8]:
dataset = Pretextdataset(filepath, permutationfilepath, patchtransform)
loader = DataLoader(dataset, batch_size=batchsize, shuffle=True, num_workers=0)
model = Fullyconnectedpretext(numpermutations=N).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learningrate)

In [9]:
from tqdm import tqdm
for epoch in range(e):
    model.train()
    total, correct, totalloss = 0, 0, 0.0

    for batch_index, (x, y) in enumerate(tqdm(loader, desc=f"Epoch {epoch+1}/{e}")):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        totalloss += loss.item()
        predictions = out.argmax(1)
        correct += (predictions == y).sum().item()
        total += y.size(0)

        if batch_index % 10 == 0:
            print(f"Batch {batch_index}/{len(loader)} | Loss: {loss.item():.4f}")

    accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}/{e} | Avg Loss: {totalloss/len(loader):.4f} | Accuracy: {accuracy:.2f}%\n")

torch.save(model.state_dict(), modelsave)
print(f"Model saved to {modelsave}")

Epoch 1/10:   0%|          | 1/425 [00:02<17:38,  2.50s/it]

Batch 0/425 | Loss: 6.2191


Epoch 1/10:   3%|▎         | 11/425 [00:19<11:47,  1.71s/it]

Batch 10/425 | Loss: 6.2100


Epoch 1/10:   5%|▍         | 21/425 [00:36<10:54,  1.62s/it]

Batch 20/425 | Loss: 6.2139


Epoch 1/10:   7%|▋         | 31/425 [00:51<10:15,  1.56s/it]

Batch 30/425 | Loss: 6.2152


Epoch 1/10:  10%|▉         | 41/425 [01:07<10:13,  1.60s/it]

Batch 40/425 | Loss: 6.2113


Epoch 1/10:  12%|█▏        | 51/425 [01:23<09:51,  1.58s/it]

Batch 50/425 | Loss: 6.2144


Epoch 1/10:  14%|█▍        | 61/425 [01:39<09:25,  1.55s/it]

Batch 60/425 | Loss: 6.2185


Epoch 1/10:  17%|█▋        | 71/425 [01:55<09:18,  1.58s/it]

Batch 70/425 | Loss: 6.1955


Epoch 1/10:  19%|█▉        | 81/425 [02:11<09:00,  1.57s/it]

Batch 80/425 | Loss: 6.2055


Epoch 1/10:  21%|██▏       | 91/425 [02:27<08:40,  1.56s/it]

Batch 90/425 | Loss: 6.1671


Epoch 1/10:  24%|██▍       | 101/425 [02:43<08:25,  1.56s/it]

Batch 100/425 | Loss: 6.0907


Epoch 1/10:  26%|██▌       | 111/425 [02:58<08:02,  1.54s/it]

Batch 110/425 | Loss: 5.3363


Epoch 1/10:  28%|██▊       | 121/425 [03:13<07:50,  1.55s/it]

Batch 120/425 | Loss: 4.3146


Epoch 1/10:  31%|███       | 131/425 [03:29<07:46,  1.59s/it]

Batch 130/425 | Loss: 3.3443


Epoch 1/10:  33%|███▎      | 141/425 [03:45<07:16,  1.54s/it]

Batch 140/425 | Loss: 2.4574


Epoch 1/10:  36%|███▌      | 151/425 [04:00<07:02,  1.54s/it]

Batch 150/425 | Loss: 2.2541


Epoch 1/10:  38%|███▊      | 161/425 [04:16<06:52,  1.56s/it]

Batch 160/425 | Loss: 1.3042


Epoch 1/10:  40%|████      | 171/425 [04:32<06:35,  1.56s/it]

Batch 170/425 | Loss: 1.6777


Epoch 1/10:  43%|████▎     | 181/425 [04:47<06:16,  1.54s/it]

Batch 180/425 | Loss: 1.7001


Epoch 1/10:  45%|████▍     | 191/425 [05:03<06:22,  1.64s/it]

Batch 190/425 | Loss: 1.0977


Epoch 1/10:  47%|████▋     | 201/425 [05:19<05:52,  1.57s/it]

Batch 200/425 | Loss: 1.1058


Epoch 1/10:  50%|████▉     | 211/425 [05:35<05:35,  1.57s/it]

Batch 210/425 | Loss: 1.1960


Epoch 1/10:  52%|█████▏    | 221/425 [05:50<05:18,  1.56s/it]

Batch 220/425 | Loss: 1.3685


Epoch 1/10:  54%|█████▍    | 231/425 [06:07<05:20,  1.65s/it]

Batch 230/425 | Loss: 0.8050


Epoch 1/10:  57%|█████▋    | 241/425 [06:23<04:58,  1.62s/it]

Batch 240/425 | Loss: 1.1502


Epoch 1/10:  59%|█████▉    | 251/425 [06:39<04:37,  1.60s/it]

Batch 250/425 | Loss: 0.9130


Epoch 1/10:  61%|██████▏   | 261/425 [06:55<04:18,  1.57s/it]

Batch 260/425 | Loss: 0.9072


Epoch 1/10:  64%|██████▍   | 271/425 [07:11<04:03,  1.58s/it]

Batch 270/425 | Loss: 0.5010


Epoch 1/10:  66%|██████▌   | 281/425 [07:27<03:52,  1.62s/it]

Batch 280/425 | Loss: 0.6762


Epoch 1/10:  68%|██████▊   | 291/425 [07:43<03:31,  1.58s/it]

Batch 290/425 | Loss: 0.4218


Epoch 1/10:  71%|███████   | 301/425 [07:59<03:19,  1.61s/it]

Batch 300/425 | Loss: 0.8309


Epoch 1/10:  73%|███████▎  | 311/425 [08:15<03:11,  1.68s/it]

Batch 310/425 | Loss: 0.4315


Epoch 1/10:  76%|███████▌  | 321/425 [08:31<02:48,  1.62s/it]

Batch 320/425 | Loss: 0.8241


Epoch 1/10:  78%|███████▊  | 331/425 [08:47<02:28,  1.58s/it]

Batch 330/425 | Loss: 0.5236


Epoch 1/10:  80%|████████  | 341/425 [09:03<02:16,  1.63s/it]

Batch 340/425 | Loss: 0.6626


Epoch 1/10:  83%|████████▎ | 351/425 [09:19<01:55,  1.56s/it]

Batch 350/425 | Loss: 0.7484


Epoch 1/10:  85%|████████▍ | 361/425 [09:36<01:43,  1.61s/it]

Batch 360/425 | Loss: 0.4514


Epoch 1/10:  87%|████████▋ | 371/425 [09:51<01:23,  1.55s/it]

Batch 370/425 | Loss: 0.7090


Epoch 1/10:  90%|████████▉ | 381/425 [10:08<01:12,  1.65s/it]

Batch 380/425 | Loss: 0.5467


Epoch 1/10:  92%|█████████▏| 391/425 [10:23<00:54,  1.61s/it]

Batch 390/425 | Loss: 0.3966


Epoch 1/10:  94%|█████████▍| 401/425 [10:39<00:36,  1.53s/it]

Batch 400/425 | Loss: 0.5246


Epoch 1/10:  97%|█████████▋| 411/425 [10:54<00:21,  1.52s/it]

Batch 410/425 | Loss: 0.6662


Epoch 1/10:  99%|█████████▉| 421/425 [11:10<00:06,  1.54s/it]

Batch 420/425 | Loss: 0.7991


Epoch 1/10: 100%|██████████| 425/425 [11:15<00:00,  1.59s/it]


Epoch 1/10 | Avg Loss: 2.4662 | Accuracy: 56.65%



Epoch 2/10:   0%|          | 1/425 [00:00<06:30,  1.09it/s]

Batch 0/425 | Loss: 0.6412


Epoch 2/10:   3%|▎         | 11/425 [00:10<06:16,  1.10it/s]

Batch 10/425 | Loss: 0.8243


Epoch 2/10:   5%|▍         | 21/425 [00:19<06:11,  1.09it/s]

Batch 20/425 | Loss: 0.4841


Epoch 2/10:   7%|▋         | 31/425 [00:28<05:55,  1.11it/s]

Batch 30/425 | Loss: 0.6121


Epoch 2/10:  10%|▉         | 41/425 [00:37<05:47,  1.10it/s]

Batch 40/425 | Loss: 0.4789


Epoch 2/10:  12%|█▏        | 51/425 [00:46<06:04,  1.03it/s]

Batch 50/425 | Loss: 0.6560


Epoch 2/10:  14%|█▍        | 61/425 [00:56<05:34,  1.09it/s]

Batch 60/425 | Loss: 0.4521


Epoch 2/10:  17%|█▋        | 71/425 [01:05<05:25,  1.09it/s]

Batch 70/425 | Loss: 0.5309


Epoch 2/10:  19%|█▉        | 81/425 [01:14<05:19,  1.08it/s]

Batch 80/425 | Loss: 0.9151


Epoch 2/10:  21%|██▏       | 91/425 [01:23<05:05,  1.09it/s]

Batch 90/425 | Loss: 0.5952


Epoch 2/10:  24%|██▍       | 101/425 [01:33<05:05,  1.06it/s]

Batch 100/425 | Loss: 0.4814


Epoch 2/10:  26%|██▌       | 111/425 [01:42<04:54,  1.07it/s]

Batch 110/425 | Loss: 0.4684


Epoch 2/10:  28%|██▊       | 121/425 [01:52<04:53,  1.04it/s]

Batch 120/425 | Loss: 0.7042


Epoch 2/10:  31%|███       | 131/425 [02:01<04:31,  1.08it/s]

Batch 130/425 | Loss: 0.4963


Epoch 2/10:  33%|███▎      | 141/425 [02:10<04:35,  1.03it/s]

Batch 140/425 | Loss: 0.5239


Epoch 2/10:  36%|███▌      | 151/425 [02:20<04:22,  1.04it/s]

Batch 150/425 | Loss: 0.2729


Epoch 2/10:  38%|███▊      | 161/425 [02:29<04:12,  1.05it/s]

Batch 160/425 | Loss: 0.3246


Epoch 2/10:  40%|████      | 171/425 [02:38<03:59,  1.06it/s]

Batch 170/425 | Loss: 0.5169


Epoch 2/10:  43%|████▎     | 181/425 [02:48<04:09,  1.02s/it]

Batch 180/425 | Loss: 0.4828


Epoch 2/10:  45%|████▍     | 191/425 [02:59<03:57,  1.02s/it]

Batch 190/425 | Loss: 0.5990


Epoch 2/10:  47%|████▋     | 201/425 [03:08<03:41,  1.01it/s]

Batch 200/425 | Loss: 0.3778


Epoch 2/10:  50%|████▉     | 211/425 [03:18<03:29,  1.02it/s]

Batch 210/425 | Loss: 0.6320


Epoch 2/10:  52%|█████▏    | 221/425 [03:28<03:12,  1.06it/s]

Batch 220/425 | Loss: 0.4030


Epoch 2/10:  54%|█████▍    | 231/425 [03:37<03:02,  1.07it/s]

Batch 230/425 | Loss: 0.5704


Epoch 2/10:  57%|█████▋    | 241/425 [03:47<02:54,  1.06it/s]

Batch 240/425 | Loss: 0.7675


Epoch 2/10:  59%|█████▉    | 251/425 [03:56<02:44,  1.06it/s]

Batch 250/425 | Loss: 0.3686


Epoch 2/10:  61%|██████▏   | 261/425 [04:06<02:35,  1.05it/s]

Batch 260/425 | Loss: 0.7081


Epoch 2/10:  64%|██████▍   | 271/425 [04:15<02:25,  1.06it/s]

Batch 270/425 | Loss: 0.2967


Epoch 2/10:  66%|██████▌   | 281/425 [04:25<02:16,  1.05it/s]

Batch 280/425 | Loss: 0.3726


Epoch 2/10:  68%|██████▊   | 291/425 [04:34<02:06,  1.06it/s]

Batch 290/425 | Loss: 0.1608


Epoch 2/10:  71%|███████   | 301/425 [04:44<01:58,  1.05it/s]

Batch 300/425 | Loss: 0.4174


Epoch 2/10:  73%|███████▎  | 311/425 [04:53<01:48,  1.05it/s]

Batch 310/425 | Loss: 0.2872


Epoch 2/10:  76%|███████▌  | 321/425 [05:03<01:39,  1.05it/s]

Batch 320/425 | Loss: 0.2353


Epoch 2/10:  78%|███████▊  | 331/425 [05:13<01:29,  1.05it/s]

Batch 330/425 | Loss: 0.4991


Epoch 2/10:  80%|████████  | 341/425 [05:22<01:19,  1.05it/s]

Batch 340/425 | Loss: 0.1915


Epoch 2/10:  83%|████████▎ | 351/425 [05:32<01:10,  1.05it/s]

Batch 350/425 | Loss: 0.4662


Epoch 2/10:  85%|████████▍ | 361/425 [05:41<01:02,  1.02it/s]

Batch 360/425 | Loss: 0.3593


Epoch 2/10:  87%|████████▋ | 371/425 [05:51<00:52,  1.04it/s]

Batch 370/425 | Loss: 0.3632


Epoch 2/10:  90%|████████▉ | 381/425 [06:01<00:43,  1.02it/s]

Batch 380/425 | Loss: 0.5999


Epoch 2/10:  92%|█████████▏| 391/425 [06:10<00:32,  1.04it/s]

Batch 390/425 | Loss: 0.4609


Epoch 2/10:  94%|█████████▍| 401/425 [06:20<00:23,  1.03it/s]

Batch 400/425 | Loss: 0.3584


Epoch 2/10:  97%|█████████▋| 411/425 [06:30<00:13,  1.03it/s]

Batch 410/425 | Loss: 0.3033


Epoch 2/10:  99%|█████████▉| 421/425 [06:40<00:03,  1.02it/s]

Batch 420/425 | Loss: 0.3306


Epoch 2/10: 100%|██████████| 425/425 [06:43<00:00,  1.05it/s]


Epoch 2/10 | Avg Loss: 0.4477 | Accuracy: 91.19%



Epoch 3/10:   0%|          | 1/425 [00:00<06:24,  1.10it/s]

Batch 0/425 | Loss: 0.4072


Epoch 3/10:   3%|▎         | 11/425 [00:09<06:11,  1.11it/s]

Batch 10/425 | Loss: 0.4801


Epoch 3/10:   5%|▍         | 21/425 [00:19<06:18,  1.07it/s]

Batch 20/425 | Loss: 0.3237


Epoch 3/10:   7%|▋         | 31/425 [00:28<05:53,  1.12it/s]

Batch 30/425 | Loss: 0.2477


Epoch 3/10:  10%|▉         | 41/425 [00:37<05:44,  1.11it/s]

Batch 40/425 | Loss: 0.3052


Epoch 3/10:  12%|█▏        | 51/425 [00:46<05:31,  1.13it/s]

Batch 50/425 | Loss: 0.3225


Epoch 3/10:  14%|█▍        | 61/425 [00:55<05:30,  1.10it/s]

Batch 60/425 | Loss: 0.2285


Epoch 3/10:  17%|█▋        | 71/425 [01:04<05:18,  1.11it/s]

Batch 70/425 | Loss: 0.3598


Epoch 3/10:  19%|█▉        | 81/425 [01:13<05:05,  1.12it/s]

Batch 80/425 | Loss: 0.1078


Epoch 3/10:  21%|██▏       | 91/425 [01:22<04:57,  1.12it/s]

Batch 90/425 | Loss: 0.4766


Epoch 3/10:  24%|██▍       | 101/425 [01:31<04:56,  1.09it/s]

Batch 100/425 | Loss: 0.3425


Epoch 3/10:  26%|██▌       | 111/425 [01:40<04:43,  1.11it/s]

Batch 110/425 | Loss: 0.2354


Epoch 3/10:  28%|██▊       | 121/425 [01:49<04:36,  1.10it/s]

Batch 120/425 | Loss: 0.4355


Epoch 3/10:  31%|███       | 131/425 [01:58<04:23,  1.11it/s]

Batch 130/425 | Loss: 0.3246


Epoch 3/10:  33%|███▎      | 141/425 [02:07<04:15,  1.11it/s]

Batch 140/425 | Loss: 0.1583


Epoch 3/10:  36%|███▌      | 151/425 [02:16<04:09,  1.10it/s]

Batch 150/425 | Loss: 0.4676


Epoch 3/10:  38%|███▊      | 161/425 [02:25<03:58,  1.11it/s]

Batch 160/425 | Loss: 0.4242


Epoch 3/10:  40%|████      | 171/425 [02:35<03:54,  1.08it/s]

Batch 170/425 | Loss: 0.5282


Epoch 3/10:  43%|████▎     | 181/425 [02:44<03:40,  1.11it/s]

Batch 180/425 | Loss: 0.4110


Epoch 3/10:  45%|████▍     | 191/425 [02:53<03:29,  1.12it/s]

Batch 190/425 | Loss: 0.2745


Epoch 3/10:  47%|████▋     | 201/425 [03:02<03:24,  1.10it/s]

Batch 200/425 | Loss: 0.3097


Epoch 3/10:  50%|████▉     | 211/425 [03:11<03:14,  1.10it/s]

Batch 210/425 | Loss: 0.2842


Epoch 3/10:  52%|█████▏    | 221/425 [03:20<03:05,  1.10it/s]

Batch 220/425 | Loss: 0.3392


Epoch 3/10:  54%|█████▍    | 231/425 [03:29<03:00,  1.07it/s]

Batch 230/425 | Loss: 0.2856


Epoch 3/10:  57%|█████▋    | 241/425 [03:38<02:48,  1.09it/s]

Batch 240/425 | Loss: 0.3470


Epoch 3/10:  59%|█████▉    | 251/425 [03:48<02:40,  1.08it/s]

Batch 250/425 | Loss: 0.3254


Epoch 3/10:  61%|██████▏   | 261/425 [03:57<02:33,  1.07it/s]

Batch 260/425 | Loss: 0.3155


Epoch 3/10:  64%|██████▍   | 271/425 [04:06<02:22,  1.08it/s]

Batch 270/425 | Loss: 0.1983


Epoch 3/10:  66%|██████▌   | 281/425 [04:15<02:13,  1.08it/s]

Batch 280/425 | Loss: 0.2747


Epoch 3/10:  68%|██████▊   | 291/425 [04:25<02:09,  1.04it/s]

Batch 290/425 | Loss: 0.1690


Epoch 3/10:  71%|███████   | 301/425 [04:34<01:55,  1.07it/s]

Batch 300/425 | Loss: 0.3022


Epoch 3/10:  73%|███████▎  | 311/425 [04:43<01:46,  1.07it/s]

Batch 310/425 | Loss: 0.2004


Epoch 3/10:  76%|███████▌  | 321/425 [04:53<01:36,  1.08it/s]

Batch 320/425 | Loss: 0.1971


Epoch 3/10:  78%|███████▊  | 331/425 [05:02<01:26,  1.09it/s]

Batch 330/425 | Loss: 0.1638


Epoch 3/10:  80%|████████  | 341/425 [05:11<01:17,  1.08it/s]

Batch 340/425 | Loss: 0.2735


Epoch 3/10:  83%|████████▎ | 351/425 [05:21<01:09,  1.06it/s]

Batch 350/425 | Loss: 0.3407


Epoch 3/10:  85%|████████▍ | 361/425 [05:30<00:59,  1.07it/s]

Batch 360/425 | Loss: 0.2976


Epoch 3/10:  87%|████████▋ | 371/425 [05:40<00:51,  1.06it/s]

Batch 370/425 | Loss: 0.2929


Epoch 3/10:  90%|████████▉ | 381/425 [05:49<00:41,  1.07it/s]

Batch 380/425 | Loss: 0.4087


Epoch 3/10:  92%|█████████▏| 391/425 [05:59<00:31,  1.07it/s]

Batch 390/425 | Loss: 0.1326


Epoch 3/10:  94%|█████████▍| 401/425 [06:08<00:22,  1.05it/s]

Batch 400/425 | Loss: 0.5587


Epoch 3/10:  97%|█████████▋| 411/425 [06:17<00:13,  1.06it/s]

Batch 410/425 | Loss: 0.3005


Epoch 3/10:  99%|█████████▉| 421/425 [06:27<00:03,  1.04it/s]

Batch 420/425 | Loss: 0.3713


Epoch 3/10: 100%|██████████| 425/425 [06:30<00:00,  1.09it/s]


Epoch 3/10 | Avg Loss: 0.3489 | Accuracy: 93.16%



Epoch 4/10:   0%|          | 1/425 [00:00<06:16,  1.13it/s]

Batch 0/425 | Loss: 0.1915


Epoch 4/10:   3%|▎         | 11/425 [00:09<06:15,  1.10it/s]

Batch 10/425 | Loss: 0.2208


Epoch 4/10:   5%|▍         | 21/425 [00:18<06:04,  1.11it/s]

Batch 20/425 | Loss: 0.4624


Epoch 4/10:   7%|▋         | 31/425 [00:28<06:02,  1.09it/s]

Batch 30/425 | Loss: 0.2291


Epoch 4/10:  10%|▉         | 41/425 [00:37<06:00,  1.06it/s]

Batch 40/425 | Loss: 0.5203


Epoch 4/10:  12%|█▏        | 51/425 [00:46<05:41,  1.10it/s]

Batch 50/425 | Loss: 0.3202


Epoch 4/10:  14%|█▍        | 61/425 [00:55<05:43,  1.06it/s]

Batch 60/425 | Loss: 0.3877


Epoch 4/10:  17%|█▋        | 71/425 [01:04<05:15,  1.12it/s]

Batch 70/425 | Loss: 0.0760


Epoch 4/10:  19%|█▉        | 81/425 [01:13<05:06,  1.12it/s]

Batch 80/425 | Loss: 0.0758


Epoch 4/10:  21%|██▏       | 91/425 [01:22<04:59,  1.12it/s]

Batch 90/425 | Loss: 0.0996


Epoch 4/10:  24%|██▍       | 101/425 [01:31<04:52,  1.11it/s]

Batch 100/425 | Loss: 0.4467


Epoch 4/10:  26%|██▌       | 111/425 [01:40<04:40,  1.12it/s]

Batch 110/425 | Loss: 0.2082


Epoch 4/10:  28%|██▊       | 121/425 [01:49<04:33,  1.11it/s]

Batch 120/425 | Loss: 0.2855


Epoch 4/10:  31%|███       | 131/425 [01:59<04:35,  1.07it/s]

Batch 130/425 | Loss: 0.1924


Epoch 4/10:  33%|███▎      | 141/425 [02:08<04:20,  1.09it/s]

Batch 140/425 | Loss: 0.4043


Epoch 4/10:  36%|███▌      | 151/425 [02:17<04:05,  1.11it/s]

Batch 150/425 | Loss: 0.4527


Epoch 4/10:  38%|███▊      | 161/425 [02:26<04:00,  1.10it/s]

Batch 160/425 | Loss: 0.4241


Epoch 4/10:  40%|████      | 171/425 [02:35<03:50,  1.10it/s]

Batch 170/425 | Loss: 0.4576


Epoch 4/10:  43%|████▎     | 181/425 [02:44<03:38,  1.12it/s]

Batch 180/425 | Loss: 0.3524


Epoch 4/10:  45%|████▍     | 191/425 [02:53<03:28,  1.12it/s]

Batch 190/425 | Loss: 0.4231


Epoch 4/10:  47%|████▋     | 201/425 [03:02<03:25,  1.09it/s]

Batch 200/425 | Loss: 0.5453


Epoch 4/10:  50%|████▉     | 211/425 [03:12<03:26,  1.03it/s]

Batch 210/425 | Loss: 0.2127


Epoch 4/10:  52%|█████▏    | 221/425 [03:21<03:07,  1.09it/s]

Batch 220/425 | Loss: 0.2801


Epoch 4/10:  54%|█████▍    | 231/425 [03:30<03:04,  1.05it/s]

Batch 230/425 | Loss: 0.3305


Epoch 4/10:  57%|█████▋    | 241/425 [03:39<02:49,  1.09it/s]

Batch 240/425 | Loss: 0.3878


Epoch 4/10:  59%|█████▉    | 251/425 [03:49<02:44,  1.06it/s]

Batch 250/425 | Loss: 0.5699


Epoch 4/10:  61%|██████▏   | 261/425 [03:58<02:31,  1.09it/s]

Batch 260/425 | Loss: 0.2742


Epoch 4/10:  64%|██████▍   | 271/425 [04:07<02:23,  1.08it/s]

Batch 270/425 | Loss: 0.4518


Epoch 4/10:  66%|██████▌   | 281/425 [04:16<02:10,  1.10it/s]

Batch 280/425 | Loss: 0.2997


Epoch 4/10:  68%|██████▊   | 291/425 [04:25<02:05,  1.07it/s]

Batch 290/425 | Loss: 0.2781


Epoch 4/10:  71%|███████   | 301/425 [04:35<01:55,  1.08it/s]

Batch 300/425 | Loss: 0.2042


Epoch 4/10:  73%|███████▎  | 311/425 [04:44<01:46,  1.07it/s]

Batch 310/425 | Loss: 0.2565


Epoch 4/10:  76%|███████▌  | 321/425 [04:54<01:38,  1.06it/s]

Batch 320/425 | Loss: 0.2527


Epoch 4/10:  78%|███████▊  | 331/425 [05:03<01:27,  1.08it/s]

Batch 330/425 | Loss: 0.2920


Epoch 4/10:  80%|████████  | 341/425 [05:12<01:17,  1.08it/s]

Batch 340/425 | Loss: 0.2494


Epoch 4/10:  83%|████████▎ | 351/425 [05:21<01:08,  1.08it/s]

Batch 350/425 | Loss: 0.1608


Epoch 4/10:  85%|████████▍ | 361/425 [05:31<01:00,  1.06it/s]

Batch 360/425 | Loss: 0.4349


Epoch 4/10:  87%|████████▋ | 371/425 [05:41<00:51,  1.04it/s]

Batch 370/425 | Loss: 0.1420


Epoch 4/10:  90%|████████▉ | 381/425 [05:50<00:42,  1.04it/s]

Batch 380/425 | Loss: 0.1578


Epoch 4/10:  92%|█████████▏| 391/425 [06:00<00:32,  1.04it/s]

Batch 390/425 | Loss: 0.3574


Epoch 4/10:  94%|█████████▍| 401/425 [06:10<00:23,  1.03it/s]

Batch 400/425 | Loss: 0.2293


Epoch 4/10:  97%|█████████▋| 411/425 [06:19<00:13,  1.06it/s]

Batch 410/425 | Loss: 0.2782


Epoch 4/10:  99%|█████████▉| 421/425 [06:29<00:03,  1.03it/s]

Batch 420/425 | Loss: 0.1787


Epoch 4/10: 100%|██████████| 425/425 [06:32<00:00,  1.08it/s]


Epoch 4/10 | Avg Loss: 0.3149 | Accuracy: 93.73%



Epoch 5/10:   0%|          | 1/425 [00:00<06:26,  1.10it/s]

Batch 0/425 | Loss: 0.3890


Epoch 5/10:   3%|▎         | 11/425 [00:10<06:18,  1.09it/s]

Batch 10/425 | Loss: 0.2176


Epoch 5/10:   5%|▍         | 21/425 [00:19<06:15,  1.08it/s]

Batch 20/425 | Loss: 0.2369


Epoch 5/10:   7%|▋         | 31/425 [00:28<05:57,  1.10it/s]

Batch 30/425 | Loss: 0.2161


Epoch 5/10:  10%|▉         | 41/425 [00:37<05:49,  1.10it/s]

Batch 40/425 | Loss: 0.3009


Epoch 5/10:  12%|█▏        | 51/425 [00:46<05:38,  1.11it/s]

Batch 50/425 | Loss: 0.2626


Epoch 5/10:  14%|█▍        | 61/425 [00:55<05:35,  1.08it/s]

Batch 60/425 | Loss: 0.2338


Epoch 5/10:  17%|█▋        | 71/425 [01:05<05:30,  1.07it/s]

Batch 70/425 | Loss: 0.2812


Epoch 5/10:  19%|█▉        | 81/425 [01:14<05:19,  1.08it/s]

Batch 80/425 | Loss: 0.3402


Epoch 5/10:  21%|██▏       | 91/425 [01:23<05:10,  1.08it/s]

Batch 90/425 | Loss: 0.1918


Epoch 5/10:  24%|██▍       | 101/425 [01:32<04:53,  1.11it/s]

Batch 100/425 | Loss: 0.5590


Epoch 5/10:  26%|██▌       | 111/425 [01:41<04:43,  1.11it/s]

Batch 110/425 | Loss: 0.4392


Epoch 5/10:  28%|██▊       | 121/425 [01:50<04:31,  1.12it/s]

Batch 120/425 | Loss: 0.2921


Epoch 5/10:  31%|███       | 131/425 [02:00<04:29,  1.09it/s]

Batch 130/425 | Loss: 0.2480


Epoch 5/10:  33%|███▎      | 141/425 [02:09<04:14,  1.12it/s]

Batch 140/425 | Loss: 0.2101


Epoch 5/10:  36%|███▌      | 151/425 [02:18<04:09,  1.10it/s]

Batch 150/425 | Loss: 0.3050


Epoch 5/10:  38%|███▊      | 161/425 [02:27<04:04,  1.08it/s]

Batch 160/425 | Loss: 0.2834


Epoch 5/10:  40%|████      | 171/425 [02:36<03:49,  1.11it/s]

Batch 170/425 | Loss: 0.3669


Epoch 5/10:  43%|████▎     | 181/425 [02:45<03:38,  1.12it/s]

Batch 180/425 | Loss: 0.4906


Epoch 5/10:  45%|████▍     | 191/425 [02:54<03:32,  1.10it/s]

Batch 190/425 | Loss: 0.3138


Epoch 5/10:  47%|████▋     | 201/425 [03:03<03:36,  1.03it/s]

Batch 200/425 | Loss: 0.3727


Epoch 5/10:  50%|████▉     | 211/425 [03:13<03:19,  1.07it/s]

Batch 210/425 | Loss: 0.3587


Epoch 5/10:  52%|█████▏    | 221/425 [03:22<03:10,  1.07it/s]

Batch 220/425 | Loss: 0.3582


Epoch 5/10:  54%|█████▍    | 231/425 [03:31<02:58,  1.09it/s]

Batch 230/425 | Loss: 0.2733


Epoch 5/10:  57%|█████▋    | 241/425 [03:41<02:52,  1.07it/s]

Batch 240/425 | Loss: 0.3250


Epoch 5/10:  59%|█████▉    | 251/425 [03:50<02:40,  1.09it/s]

Batch 250/425 | Loss: 0.1399


Epoch 5/10:  61%|██████▏   | 261/425 [03:59<02:31,  1.08it/s]

Batch 260/425 | Loss: 0.1887


Epoch 5/10:  64%|██████▍   | 271/425 [04:09<02:24,  1.07it/s]

Batch 270/425 | Loss: 0.1289


Epoch 5/10:  66%|██████▌   | 281/425 [04:18<02:17,  1.05it/s]

Batch 280/425 | Loss: 0.2664


Epoch 5/10:  68%|██████▊   | 291/425 [04:28<02:05,  1.07it/s]

Batch 290/425 | Loss: 0.4193


Epoch 5/10:  71%|███████   | 301/425 [04:37<01:58,  1.05it/s]

Batch 300/425 | Loss: 0.2147


Epoch 5/10:  73%|███████▎  | 311/425 [04:46<01:47,  1.06it/s]

Batch 310/425 | Loss: 0.4863


Epoch 5/10:  76%|███████▌  | 321/425 [04:56<01:39,  1.04it/s]

Batch 320/425 | Loss: 0.2260


Epoch 5/10:  78%|███████▊  | 331/425 [05:05<01:28,  1.06it/s]

Batch 330/425 | Loss: 0.2032


Epoch 5/10:  80%|████████  | 341/425 [05:15<01:18,  1.07it/s]

Batch 340/425 | Loss: 0.1883


Epoch 5/10:  83%|████████▎ | 351/425 [05:24<01:10,  1.06it/s]

Batch 350/425 | Loss: 0.4539


Epoch 5/10:  85%|████████▍ | 361/425 [05:34<01:00,  1.05it/s]

Batch 360/425 | Loss: 0.4983


Epoch 5/10:  87%|████████▋ | 371/425 [05:44<00:54,  1.00s/it]

Batch 370/425 | Loss: 0.7344


Epoch 5/10:  90%|████████▉ | 381/425 [05:54<00:42,  1.03it/s]

Batch 380/425 | Loss: 0.3209


Epoch 5/10:  92%|█████████▏| 391/425 [06:03<00:32,  1.04it/s]

Batch 390/425 | Loss: 0.1890


Epoch 5/10:  94%|█████████▍| 401/425 [06:13<00:24,  1.03s/it]

Batch 400/425 | Loss: 0.1983


Epoch 5/10:  97%|█████████▋| 411/425 [06:23<00:13,  1.03it/s]

Batch 410/425 | Loss: 0.1836


Epoch 5/10:  99%|█████████▉| 421/425 [06:32<00:03,  1.05it/s]

Batch 420/425 | Loss: 0.3100


Epoch 5/10: 100%|██████████| 425/425 [06:36<00:00,  1.07it/s]


Epoch 5/10 | Avg Loss: 0.2909 | Accuracy: 94.13%



Epoch 6/10:   0%|          | 1/425 [00:00<06:15,  1.13it/s]

Batch 0/425 | Loss: 0.4167


Epoch 6/10:   3%|▎         | 11/425 [00:10<06:18,  1.09it/s]

Batch 10/425 | Loss: 0.1494


Epoch 6/10:   5%|▍         | 21/425 [00:19<06:04,  1.11it/s]

Batch 20/425 | Loss: 0.1458


Epoch 6/10:   7%|▋         | 31/425 [00:28<05:53,  1.11it/s]

Batch 30/425 | Loss: 0.3109


Epoch 6/10:  10%|▉         | 41/425 [00:37<05:48,  1.10it/s]

Batch 40/425 | Loss: 0.1218


Epoch 6/10:  12%|█▏        | 51/425 [00:46<05:40,  1.10it/s]

Batch 50/425 | Loss: 0.2402


Epoch 6/10:  14%|█▍        | 61/425 [00:55<05:30,  1.10it/s]

Batch 60/425 | Loss: 0.4897


Epoch 6/10:  17%|█▋        | 71/425 [01:04<05:17,  1.11it/s]

Batch 70/425 | Loss: 0.4926


Epoch 6/10:  19%|█▉        | 81/425 [01:13<05:20,  1.07it/s]

Batch 80/425 | Loss: 0.2861


Epoch 6/10:  21%|██▏       | 91/425 [01:22<05:07,  1.09it/s]

Batch 90/425 | Loss: 0.3096


Epoch 6/10:  24%|██▍       | 101/425 [01:31<04:53,  1.10it/s]

Batch 100/425 | Loss: 0.2390


Epoch 6/10:  26%|██▌       | 111/425 [01:41<04:50,  1.08it/s]

Batch 110/425 | Loss: 0.2880


Epoch 6/10:  28%|██▊       | 121/425 [01:50<04:37,  1.09it/s]

Batch 120/425 | Loss: 0.2613


Epoch 6/10:  31%|███       | 131/425 [01:59<04:25,  1.11it/s]

Batch 130/425 | Loss: 0.1770


Epoch 6/10:  33%|███▎      | 141/425 [02:08<04:22,  1.08it/s]

Batch 140/425 | Loss: 0.3386


Epoch 6/10:  36%|███▌      | 151/425 [02:17<04:14,  1.08it/s]

Batch 150/425 | Loss: 0.1913


Epoch 6/10:  38%|███▊      | 161/425 [02:27<04:02,  1.09it/s]

Batch 160/425 | Loss: 0.2730


Epoch 6/10:  40%|████      | 171/425 [02:36<03:53,  1.09it/s]

Batch 170/425 | Loss: 0.2847


Epoch 6/10:  43%|████▎     | 181/425 [02:45<03:45,  1.08it/s]

Batch 180/425 | Loss: 0.4724


Epoch 6/10:  45%|████▍     | 191/425 [02:54<03:32,  1.10it/s]

Batch 190/425 | Loss: 0.2394


Epoch 6/10:  47%|████▋     | 201/425 [03:03<03:21,  1.11it/s]

Batch 200/425 | Loss: 0.3932


Epoch 6/10:  50%|████▉     | 211/425 [03:12<03:13,  1.10it/s]

Batch 210/425 | Loss: 0.3994


Epoch 6/10:  52%|█████▏    | 221/425 [03:21<03:07,  1.09it/s]

Batch 220/425 | Loss: 0.2830


Epoch 6/10:  54%|█████▍    | 231/425 [03:31<02:59,  1.08it/s]

Batch 230/425 | Loss: 0.2906


Epoch 6/10:  57%|█████▋    | 241/425 [03:40<02:48,  1.09it/s]

Batch 240/425 | Loss: 0.3771


Epoch 6/10:  59%|█████▉    | 251/425 [03:49<02:42,  1.07it/s]

Batch 250/425 | Loss: 0.1492


Epoch 6/10:  61%|██████▏   | 261/425 [03:58<02:30,  1.09it/s]

Batch 260/425 | Loss: 0.3164


Epoch 6/10:  64%|██████▍   | 271/425 [04:08<02:21,  1.09it/s]

Batch 270/425 | Loss: 0.2627


Epoch 6/10:  66%|██████▌   | 281/425 [04:17<02:17,  1.05it/s]

Batch 280/425 | Loss: 0.2230


Epoch 6/10:  68%|██████▊   | 291/425 [04:26<02:05,  1.07it/s]

Batch 290/425 | Loss: 0.4265


Epoch 6/10:  71%|███████   | 301/425 [04:36<01:59,  1.04it/s]

Batch 300/425 | Loss: 0.1589


Epoch 6/10:  73%|███████▎  | 311/425 [04:45<01:49,  1.04it/s]

Batch 310/425 | Loss: 0.2347


Epoch 6/10:  76%|███████▌  | 321/425 [04:55<01:38,  1.05it/s]

Batch 320/425 | Loss: 0.0953


Epoch 6/10:  78%|███████▊  | 331/425 [05:04<01:27,  1.08it/s]

Batch 330/425 | Loss: 0.2456


Epoch 6/10:  80%|████████  | 341/425 [05:14<01:20,  1.04it/s]

Batch 340/425 | Loss: 0.1369


Epoch 6/10:  83%|████████▎ | 351/425 [05:24<01:12,  1.03it/s]

Batch 350/425 | Loss: 0.2689


Epoch 6/10:  85%|████████▍ | 361/425 [05:33<01:00,  1.06it/s]

Batch 360/425 | Loss: 0.1850


Epoch 6/10:  87%|████████▋ | 371/425 [05:43<00:51,  1.06it/s]

Batch 370/425 | Loss: 0.3331


Epoch 6/10:  90%|████████▉ | 381/425 [05:52<00:41,  1.05it/s]

Batch 380/425 | Loss: 0.2815


Epoch 6/10:  92%|█████████▏| 391/425 [06:02<00:32,  1.04it/s]

Batch 390/425 | Loss: 0.2539


Epoch 6/10:  94%|█████████▍| 401/425 [06:11<00:22,  1.06it/s]

Batch 400/425 | Loss: 0.2776


Epoch 6/10:  97%|█████████▋| 411/425 [06:21<00:14,  1.02s/it]

Batch 410/425 | Loss: 0.3029


Epoch 6/10:  99%|█████████▉| 421/425 [06:31<00:04,  1.02s/it]

Batch 420/425 | Loss: 0.2101


Epoch 6/10: 100%|██████████| 425/425 [06:34<00:00,  1.08it/s]


Epoch 6/10 | Avg Loss: 0.2744 | Accuracy: 94.34%



Epoch 7/10:   0%|          | 1/425 [00:00<06:23,  1.11it/s]

Batch 0/425 | Loss: 0.2699


Epoch 7/10:   3%|▎         | 11/425 [00:10<06:38,  1.04it/s]

Batch 10/425 | Loss: 0.2255


Epoch 7/10:   5%|▍         | 21/425 [00:19<06:06,  1.10it/s]

Batch 20/425 | Loss: 0.2628


Epoch 7/10:   7%|▋         | 31/425 [00:28<05:53,  1.11it/s]

Batch 30/425 | Loss: 0.2993


Epoch 7/10:  10%|▉         | 41/425 [00:37<05:41,  1.12it/s]

Batch 40/425 | Loss: 0.4200


Epoch 7/10:  12%|█▏        | 51/425 [00:46<05:37,  1.11it/s]

Batch 50/425 | Loss: 0.1469


Epoch 7/10:  14%|█▍        | 61/425 [00:55<05:35,  1.09it/s]

Batch 60/425 | Loss: 0.1777


Epoch 7/10:  17%|█▋        | 71/425 [01:04<05:21,  1.10it/s]

Batch 70/425 | Loss: 0.1856


Epoch 7/10:  19%|█▉        | 81/425 [01:13<05:09,  1.11it/s]

Batch 80/425 | Loss: 0.1827


Epoch 7/10:  21%|██▏       | 91/425 [01:22<05:03,  1.10it/s]

Batch 90/425 | Loss: 0.2256


Epoch 7/10:  24%|██▍       | 101/425 [01:31<04:53,  1.11it/s]

Batch 100/425 | Loss: 0.3097


Epoch 7/10:  26%|██▌       | 111/425 [01:40<04:43,  1.11it/s]

Batch 110/425 | Loss: 0.1602


Epoch 7/10:  28%|██▊       | 121/425 [01:49<04:43,  1.07it/s]

Batch 120/425 | Loss: 0.2043


Epoch 7/10:  31%|███       | 131/425 [01:58<04:27,  1.10it/s]

Batch 130/425 | Loss: 0.3306


Epoch 7/10:  33%|███▎      | 141/425 [02:07<04:14,  1.12it/s]

Batch 140/425 | Loss: 0.3010


Epoch 7/10:  36%|███▌      | 151/425 [02:16<04:09,  1.10it/s]

Batch 150/425 | Loss: 0.2933


Epoch 7/10:  38%|███▊      | 161/425 [02:25<03:57,  1.11it/s]

Batch 160/425 | Loss: 0.3721


Epoch 7/10:  40%|████      | 171/425 [02:35<03:50,  1.10it/s]

Batch 170/425 | Loss: 0.2453


Epoch 7/10:  43%|████▎     | 181/425 [02:44<03:42,  1.09it/s]

Batch 180/425 | Loss: 0.2954


Epoch 7/10:  45%|████▍     | 191/425 [02:53<03:36,  1.08it/s]

Batch 190/425 | Loss: 0.1974


Epoch 7/10:  47%|████▋     | 201/425 [03:02<03:27,  1.08it/s]

Batch 200/425 | Loss: 0.3532


Epoch 7/10:  50%|████▉     | 211/425 [03:12<03:15,  1.09it/s]

Batch 210/425 | Loss: 0.3804


Epoch 7/10:  52%|█████▏    | 221/425 [03:21<03:07,  1.09it/s]

Batch 220/425 | Loss: 0.5027


Epoch 7/10:  54%|█████▍    | 231/425 [03:30<03:00,  1.08it/s]

Batch 230/425 | Loss: 0.1836


Epoch 7/10:  57%|█████▋    | 241/425 [03:39<02:47,  1.10it/s]

Batch 240/425 | Loss: 0.3240


Epoch 7/10:  59%|█████▉    | 251/425 [03:48<02:42,  1.07it/s]

Batch 250/425 | Loss: 0.3356


Epoch 7/10:  61%|██████▏   | 261/425 [03:58<02:32,  1.08it/s]

Batch 260/425 | Loss: 0.2708


Epoch 7/10:  64%|██████▍   | 271/425 [04:08<02:32,  1.01it/s]

Batch 270/425 | Loss: 0.4026


Epoch 7/10:  66%|██████▌   | 281/425 [04:17<02:17,  1.05it/s]

Batch 280/425 | Loss: 0.2235


Epoch 7/10:  68%|██████▊   | 291/425 [04:27<02:06,  1.06it/s]

Batch 290/425 | Loss: 0.5082


Epoch 7/10:  71%|███████   | 301/425 [04:36<01:58,  1.05it/s]

Batch 300/425 | Loss: 0.1687


Epoch 7/10:  73%|███████▎  | 311/425 [04:46<01:48,  1.05it/s]

Batch 310/425 | Loss: 0.3126


Epoch 7/10:  76%|███████▌  | 321/425 [04:55<01:40,  1.04it/s]

Batch 320/425 | Loss: 0.1376


Epoch 7/10:  78%|███████▊  | 331/425 [05:05<01:29,  1.05it/s]

Batch 330/425 | Loss: 0.2211


Epoch 7/10:  80%|████████  | 341/425 [05:15<01:21,  1.04it/s]

Batch 340/425 | Loss: 0.1537


Epoch 7/10:  83%|████████▎ | 351/425 [05:24<01:09,  1.06it/s]

Batch 350/425 | Loss: 0.3482


Epoch 7/10:  85%|████████▍ | 361/425 [05:34<01:00,  1.05it/s]

Batch 360/425 | Loss: 0.4401


Epoch 7/10:  87%|████████▋ | 371/425 [05:43<00:51,  1.05it/s]

Batch 370/425 | Loss: 0.1552


Epoch 7/10:  90%|████████▉ | 381/425 [05:53<00:42,  1.03it/s]

Batch 380/425 | Loss: 0.2918


Epoch 7/10:  92%|█████████▏| 391/425 [06:03<00:33,  1.03it/s]

Batch 390/425 | Loss: 0.3241


Epoch 7/10:  94%|█████████▍| 401/425 [06:12<00:23,  1.03it/s]

Batch 400/425 | Loss: 0.2656


Epoch 7/10:  97%|█████████▋| 411/425 [06:22<00:13,  1.04it/s]

Batch 410/425 | Loss: 0.3315


Epoch 7/10:  99%|█████████▉| 421/425 [06:32<00:03,  1.05it/s]

Batch 420/425 | Loss: 0.1629


Epoch 7/10: 100%|██████████| 425/425 [06:35<00:00,  1.08it/s]


Epoch 7/10 | Avg Loss: 0.2606 | Accuracy: 94.64%



Epoch 8/10:   0%|          | 1/425 [00:00<06:17,  1.12it/s]

Batch 0/425 | Loss: 0.1573


Epoch 8/10:   3%|▎         | 11/425 [00:10<06:36,  1.04it/s]

Batch 10/425 | Loss: 0.2371


Epoch 8/10:   5%|▍         | 21/425 [00:19<06:20,  1.06it/s]

Batch 20/425 | Loss: 0.2456


Epoch 8/10:   7%|▋         | 31/425 [00:28<05:59,  1.10it/s]

Batch 30/425 | Loss: 0.1894


Epoch 8/10:  10%|▉         | 41/425 [00:37<05:50,  1.09it/s]

Batch 40/425 | Loss: 0.1423


Epoch 8/10:  12%|█▏        | 51/425 [00:46<05:36,  1.11it/s]

Batch 50/425 | Loss: 0.1126


Epoch 8/10:  14%|█▍        | 61/425 [00:55<05:24,  1.12it/s]

Batch 60/425 | Loss: 0.1871


Epoch 8/10:  17%|█▋        | 71/425 [01:04<05:14,  1.12it/s]

Batch 70/425 | Loss: 0.3192


Epoch 8/10:  19%|█▉        | 81/425 [01:13<05:19,  1.08it/s]

Batch 80/425 | Loss: 0.1647


Epoch 8/10:  21%|██▏       | 91/425 [01:22<05:02,  1.10it/s]

Batch 90/425 | Loss: 0.1918


Epoch 8/10:  24%|██▍       | 101/425 [01:32<05:07,  1.05it/s]

Batch 100/425 | Loss: 0.2553


Epoch 8/10:  26%|██▌       | 111/425 [01:41<04:47,  1.09it/s]

Batch 110/425 | Loss: 0.2323


Epoch 8/10:  28%|██▊       | 121/425 [01:50<04:39,  1.09it/s]

Batch 120/425 | Loss: 0.3513


Epoch 8/10:  31%|███       | 131/425 [01:59<04:23,  1.12it/s]

Batch 130/425 | Loss: 0.2947


Epoch 8/10:  33%|███▎      | 141/425 [02:08<04:13,  1.12it/s]

Batch 140/425 | Loss: 0.0981


Epoch 8/10:  36%|███▌      | 151/425 [02:17<04:05,  1.12it/s]

Batch 150/425 | Loss: 0.3095


Epoch 8/10:  38%|███▊      | 161/425 [02:26<03:57,  1.11it/s]

Batch 160/425 | Loss: 0.2528


Epoch 8/10:  40%|████      | 171/425 [02:35<03:51,  1.10it/s]

Batch 170/425 | Loss: 0.2008


Epoch 8/10:  43%|████▎     | 181/425 [02:44<03:42,  1.10it/s]

Batch 180/425 | Loss: 0.1659


Epoch 8/10:  45%|████▍     | 191/425 [02:53<03:30,  1.11it/s]

Batch 190/425 | Loss: 0.1742


Epoch 8/10:  47%|████▋     | 201/425 [03:02<03:19,  1.12it/s]

Batch 200/425 | Loss: 0.3463


Epoch 8/10:  50%|████▉     | 211/425 [03:11<03:12,  1.11it/s]

Batch 210/425 | Loss: 0.1633


Epoch 8/10:  52%|█████▏    | 221/425 [03:20<03:05,  1.10it/s]

Batch 220/425 | Loss: 0.3108


Epoch 8/10:  54%|█████▍    | 231/425 [03:29<02:56,  1.10it/s]

Batch 230/425 | Loss: 0.1516


Epoch 8/10:  57%|█████▋    | 241/425 [03:38<02:49,  1.09it/s]

Batch 240/425 | Loss: 0.3428


Epoch 8/10:  59%|█████▉    | 251/425 [03:48<02:39,  1.09it/s]

Batch 250/425 | Loss: 0.1794


Epoch 8/10:  61%|██████▏   | 261/425 [03:57<02:28,  1.10it/s]

Batch 260/425 | Loss: 0.2640


Epoch 8/10:  64%|██████▍   | 271/425 [04:06<02:20,  1.10it/s]

Batch 270/425 | Loss: 0.4436


Epoch 8/10:  66%|██████▌   | 281/425 [04:15<02:14,  1.07it/s]

Batch 280/425 | Loss: 0.2390


Epoch 8/10:  68%|██████▊   | 291/425 [04:25<02:08,  1.04it/s]

Batch 290/425 | Loss: 0.1250


Epoch 8/10:  71%|███████   | 301/425 [04:34<02:01,  1.02it/s]

Batch 300/425 | Loss: 0.1214


Epoch 8/10:  73%|███████▎  | 311/425 [04:44<01:47,  1.06it/s]

Batch 310/425 | Loss: 0.4228


Epoch 8/10:  76%|███████▌  | 321/425 [04:53<01:37,  1.07it/s]

Batch 320/425 | Loss: 0.2142


Epoch 8/10:  78%|███████▊  | 331/425 [05:03<01:30,  1.04it/s]

Batch 330/425 | Loss: 0.1681


Epoch 8/10:  80%|████████  | 341/425 [05:12<01:19,  1.05it/s]

Batch 340/425 | Loss: 0.0674


Epoch 8/10:  83%|████████▎ | 351/425 [05:22<01:11,  1.03it/s]

Batch 350/425 | Loss: 0.5908


Epoch 8/10:  85%|████████▍ | 361/425 [05:31<01:01,  1.04it/s]

Batch 360/425 | Loss: 0.3860


Epoch 8/10:  87%|████████▋ | 371/425 [05:41<00:50,  1.07it/s]

Batch 370/425 | Loss: 0.1783


Epoch 8/10:  90%|████████▉ | 381/425 [05:50<00:40,  1.08it/s]

Batch 380/425 | Loss: 0.1580


Epoch 8/10:  92%|█████████▏| 391/425 [06:00<00:32,  1.06it/s]

Batch 390/425 | Loss: 0.3081


Epoch 8/10:  94%|█████████▍| 401/425 [06:09<00:22,  1.05it/s]

Batch 400/425 | Loss: 0.2386


Epoch 8/10:  97%|█████████▋| 411/425 [06:19<00:13,  1.05it/s]

Batch 410/425 | Loss: 0.3442


Epoch 8/10:  99%|█████████▉| 421/425 [06:28<00:03,  1.04it/s]

Batch 420/425 | Loss: 0.1879


Epoch 8/10: 100%|██████████| 425/425 [06:31<00:00,  1.08it/s]


Epoch 8/10 | Avg Loss: 0.2479 | Accuracy: 94.90%



Epoch 9/10:   0%|          | 1/425 [00:00<06:24,  1.10it/s]

Batch 0/425 | Loss: 0.1487


Epoch 9/10:   3%|▎         | 11/425 [00:09<06:10,  1.12it/s]

Batch 10/425 | Loss: 0.1972


Epoch 9/10:   5%|▍         | 21/425 [00:18<06:00,  1.12it/s]

Batch 20/425 | Loss: 0.2721


Epoch 9/10:   7%|▋         | 31/425 [00:28<06:23,  1.03it/s]

Batch 30/425 | Loss: 0.1215


Epoch 9/10:  10%|▉         | 41/425 [00:37<05:43,  1.12it/s]

Batch 40/425 | Loss: 0.2545


Epoch 9/10:  12%|█▏        | 51/425 [00:46<05:35,  1.12it/s]

Batch 50/425 | Loss: 0.3925


Epoch 9/10:  14%|█▍        | 61/425 [00:55<05:37,  1.08it/s]

Batch 60/425 | Loss: 0.1796


Epoch 9/10:  17%|█▋        | 71/425 [01:04<05:15,  1.12it/s]

Batch 70/425 | Loss: 0.1054


Epoch 9/10:  19%|█▉        | 81/425 [01:13<05:08,  1.12it/s]

Batch 80/425 | Loss: 0.3030


Epoch 9/10:  21%|██▏       | 91/425 [01:22<04:57,  1.12it/s]

Batch 90/425 | Loss: 0.2073


Epoch 9/10:  24%|██▍       | 101/425 [01:31<04:52,  1.11it/s]

Batch 100/425 | Loss: 0.1783


Epoch 9/10:  26%|██▌       | 111/425 [01:40<04:42,  1.11it/s]

Batch 110/425 | Loss: 0.2059


Epoch 9/10:  28%|██▊       | 121/425 [01:49<04:33,  1.11it/s]

Batch 120/425 | Loss: 0.1560


Epoch 9/10:  31%|███       | 131/425 [01:58<04:32,  1.08it/s]

Batch 130/425 | Loss: 0.2625


Epoch 9/10:  33%|███▎      | 141/425 [02:07<04:15,  1.11it/s]

Batch 140/425 | Loss: 0.1772


Epoch 9/10:  36%|███▌      | 151/425 [02:16<04:06,  1.11it/s]

Batch 150/425 | Loss: 0.2286


Epoch 9/10:  38%|███▊      | 161/425 [02:25<03:57,  1.11it/s]

Batch 160/425 | Loss: 0.3083


Epoch 9/10:  40%|████      | 171/425 [02:34<03:46,  1.12it/s]

Batch 170/425 | Loss: 0.1517


Epoch 9/10:  43%|████▎     | 181/425 [02:43<03:40,  1.11it/s]

Batch 180/425 | Loss: 0.1470


Epoch 9/10:  45%|████▍     | 191/425 [02:52<03:29,  1.12it/s]

Batch 190/425 | Loss: 0.2778


Epoch 9/10:  47%|████▋     | 201/425 [03:01<03:24,  1.09it/s]

Batch 200/425 | Loss: 0.1978


Epoch 9/10:  50%|████▉     | 211/425 [03:10<03:13,  1.11it/s]

Batch 210/425 | Loss: 0.2102


Epoch 9/10:  52%|█████▏    | 221/425 [03:19<03:01,  1.12it/s]

Batch 220/425 | Loss: 0.1355


Epoch 9/10:  54%|█████▍    | 231/425 [03:28<02:54,  1.11it/s]

Batch 230/425 | Loss: 0.2224


Epoch 9/10:  57%|█████▋    | 241/425 [03:37<02:46,  1.10it/s]

Batch 240/425 | Loss: 0.1176


Epoch 9/10:  59%|█████▉    | 251/425 [03:46<02:41,  1.08it/s]

Batch 250/425 | Loss: 0.0913


Epoch 9/10:  61%|██████▏   | 261/425 [03:55<02:27,  1.11it/s]

Batch 260/425 | Loss: 0.2178


Epoch 9/10:  64%|██████▍   | 271/425 [04:04<02:20,  1.09it/s]

Batch 270/425 | Loss: 0.3007


Epoch 9/10:  66%|██████▌   | 281/425 [04:13<02:11,  1.10it/s]

Batch 280/425 | Loss: 0.4198


Epoch 9/10:  68%|██████▊   | 291/425 [04:22<02:01,  1.10it/s]

Batch 290/425 | Loss: 0.3053


Epoch 9/10:  71%|███████   | 301/425 [04:31<01:50,  1.12it/s]

Batch 300/425 | Loss: 0.5313


Epoch 9/10:  73%|███████▎  | 311/425 [04:41<01:43,  1.10it/s]

Batch 310/425 | Loss: 0.2700


Epoch 9/10:  76%|███████▌  | 321/425 [04:50<01:34,  1.09it/s]

Batch 320/425 | Loss: 0.4518


Epoch 9/10:  78%|███████▊  | 331/425 [04:59<01:26,  1.09it/s]

Batch 330/425 | Loss: 0.2271


Epoch 9/10:  80%|████████  | 341/425 [05:08<01:18,  1.07it/s]

Batch 340/425 | Loss: 0.3741


Epoch 9/10:  83%|████████▎ | 351/425 [05:17<01:07,  1.10it/s]

Batch 350/425 | Loss: 0.2169


Epoch 9/10:  85%|████████▍ | 361/425 [05:26<00:58,  1.09it/s]

Batch 360/425 | Loss: 0.1271


Epoch 9/10:  87%|████████▋ | 371/425 [05:36<00:50,  1.08it/s]

Batch 370/425 | Loss: 0.0798


Epoch 9/10:  90%|████████▉ | 381/425 [05:45<00:40,  1.08it/s]

Batch 380/425 | Loss: 0.3143


Epoch 9/10:  92%|█████████▏| 391/425 [05:54<00:31,  1.08it/s]

Batch 390/425 | Loss: 0.2568


Epoch 9/10:  94%|█████████▍| 401/425 [06:04<00:22,  1.05it/s]

Batch 400/425 | Loss: 0.0565


Epoch 9/10:  97%|█████████▋| 411/425 [06:13<00:13,  1.06it/s]

Batch 410/425 | Loss: 0.1880


Epoch 9/10:  99%|█████████▉| 421/425 [06:23<00:03,  1.08it/s]

Batch 420/425 | Loss: 0.0920


Epoch 9/10: 100%|██████████| 425/425 [06:26<00:00,  1.10it/s]


Epoch 9/10 | Avg Loss: 0.2491 | Accuracy: 94.85%



Epoch 10/10:   0%|          | 1/425 [00:00<06:20,  1.12it/s]

Batch 0/425 | Loss: 0.2078


Epoch 10/10:   3%|▎         | 11/425 [00:09<06:14,  1.11it/s]

Batch 10/425 | Loss: 0.3056


Epoch 10/10:   5%|▍         | 21/425 [00:18<06:08,  1.09it/s]

Batch 20/425 | Loss: 0.1760


Epoch 10/10:   7%|▋         | 31/425 [00:28<06:06,  1.07it/s]

Batch 30/425 | Loss: 0.2354


Epoch 10/10:  10%|▉         | 41/425 [00:37<05:58,  1.07it/s]

Batch 40/425 | Loss: 0.2566


Epoch 10/10:  12%|█▏        | 51/425 [00:47<06:05,  1.02it/s]

Batch 50/425 | Loss: 0.1014


Epoch 10/10:  14%|█▍        | 61/425 [00:56<05:39,  1.07it/s]

Batch 60/425 | Loss: 0.0566


Epoch 10/10:  17%|█▋        | 71/425 [01:05<05:24,  1.09it/s]

Batch 70/425 | Loss: 0.2495


Epoch 10/10:  19%|█▉        | 81/425 [01:14<05:13,  1.10it/s]

Batch 80/425 | Loss: 0.2563


Epoch 10/10:  21%|██▏       | 91/425 [01:24<05:58,  1.07s/it]

Batch 90/425 | Loss: 0.2068


Epoch 10/10:  24%|██▍       | 101/425 [01:33<04:51,  1.11it/s]

Batch 100/425 | Loss: 0.2785


Epoch 10/10:  26%|██▌       | 111/425 [01:42<04:52,  1.08it/s]

Batch 110/425 | Loss: 0.3886


Epoch 10/10:  28%|██▊       | 121/425 [01:52<04:40,  1.08it/s]

Batch 120/425 | Loss: 0.1558


Epoch 10/10:  31%|███       | 131/425 [02:01<04:26,  1.10it/s]

Batch 130/425 | Loss: 0.1602


Epoch 10/10:  33%|███▎      | 141/425 [02:10<04:14,  1.11it/s]

Batch 140/425 | Loss: 0.1683


Epoch 10/10:  36%|███▌      | 151/425 [02:19<04:08,  1.10it/s]

Batch 150/425 | Loss: 0.4722


Epoch 10/10:  38%|███▊      | 161/425 [02:28<04:00,  1.10it/s]

Batch 160/425 | Loss: 0.2156


Epoch 10/10:  40%|████      | 171/425 [02:37<03:52,  1.09it/s]

Batch 170/425 | Loss: 0.1630


Epoch 10/10:  43%|████▎     | 181/425 [02:47<03:45,  1.08it/s]

Batch 180/425 | Loss: 0.1013


Epoch 10/10:  45%|████▍     | 191/425 [02:56<03:34,  1.09it/s]

Batch 190/425 | Loss: 0.2996


Epoch 10/10:  47%|████▋     | 201/425 [03:05<03:22,  1.11it/s]

Batch 200/425 | Loss: 0.2346


Epoch 10/10:  50%|████▉     | 211/425 [03:14<03:16,  1.09it/s]

Batch 210/425 | Loss: 0.2688


Epoch 10/10:  52%|█████▏    | 221/425 [03:23<03:08,  1.08it/s]

Batch 220/425 | Loss: 0.2457


Epoch 10/10:  54%|█████▍    | 231/425 [03:33<02:59,  1.08it/s]

Batch 230/425 | Loss: 0.3045


Epoch 10/10:  57%|█████▋    | 241/425 [03:42<02:52,  1.07it/s]

Batch 240/425 | Loss: 0.3375


Epoch 10/10:  59%|█████▉    | 251/425 [03:51<02:39,  1.09it/s]

Batch 250/425 | Loss: 0.2602


Epoch 10/10:  61%|██████▏   | 261/425 [04:01<02:40,  1.02it/s]

Batch 260/425 | Loss: 0.2676


Epoch 10/10:  64%|██████▍   | 271/425 [04:10<02:25,  1.05it/s]

Batch 270/425 | Loss: 0.3607


Epoch 10/10:  66%|██████▌   | 281/425 [04:20<02:14,  1.07it/s]

Batch 280/425 | Loss: 0.2845


Epoch 10/10:  68%|██████▊   | 291/425 [04:29<02:07,  1.05it/s]

Batch 290/425 | Loss: 0.1819


Epoch 10/10:  71%|███████   | 301/425 [04:38<01:55,  1.08it/s]

Batch 300/425 | Loss: 0.3632


Epoch 10/10:  73%|███████▎  | 311/425 [04:48<01:45,  1.08it/s]

Batch 310/425 | Loss: 0.1218


Epoch 10/10:  76%|███████▌  | 321/425 [04:57<01:35,  1.08it/s]

Batch 320/425 | Loss: 0.2250


Epoch 10/10:  78%|███████▊  | 331/425 [05:06<01:26,  1.09it/s]

Batch 330/425 | Loss: 0.2580


Epoch 10/10:  80%|████████  | 341/425 [05:15<01:18,  1.07it/s]

Batch 340/425 | Loss: 0.2154


Epoch 10/10:  83%|████████▎ | 351/425 [05:25<01:10,  1.06it/s]

Batch 350/425 | Loss: 0.2085


Epoch 10/10:  85%|████████▍ | 361/425 [05:35<01:02,  1.03it/s]

Batch 360/425 | Loss: 0.1565


Epoch 10/10:  87%|████████▋ | 371/425 [05:44<00:52,  1.03it/s]

Batch 370/425 | Loss: 0.1583


Epoch 10/10:  90%|████████▉ | 381/425 [05:54<00:41,  1.06it/s]

Batch 380/425 | Loss: 0.2456


Epoch 10/10:  92%|█████████▏| 391/425 [06:03<00:32,  1.05it/s]

Batch 390/425 | Loss: 0.1201


Epoch 10/10:  94%|█████████▍| 401/425 [06:13<00:23,  1.04it/s]

Batch 400/425 | Loss: 0.3020


Epoch 10/10:  97%|█████████▋| 411/425 [06:23<00:13,  1.05it/s]

Batch 410/425 | Loss: 0.1710


Epoch 10/10:  99%|█████████▉| 421/425 [06:32<00:03,  1.04it/s]

Batch 420/425 | Loss: 0.2117


Epoch 10/10: 100%|██████████| 425/425 [06:35<00:00,  1.07it/s]


Epoch 10/10 | Avg Loss: 0.2317 | Accuracy: 95.19%

Model saved to /kaggle/working/jigsaw_encoder_3layer.pth


In [12]:
from sklearn.metrics import accuracy_score
from tqdm import tqdm

class Downstream(nn.Module):
    def __init__(self, num_classes=38):
        super(Downstream, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),     
            nn.ReLU(),
            nn.MaxPool2d(2),                                

            nn.Conv2d(64, 128, kernel_size=3, padding=1),   
            nn.ReLU(),
            nn.MaxPool2d(2),                                

            nn.Conv2d(128, 128, kernel_size=3, padding=1),   
            nn.ReLU(),
            nn.MaxPool2d(4),                                

            nn.Conv2d(128, 256, kernel_size=3, padding=1),  
            nn.ReLU(),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),  
            nn.ReLU(),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),  
            nn.ReLU(),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),  
            nn.ReLU(),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),  
            nn.ReLU(),

            nn.MaxPool2d(4)                                 
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),                                   
            nn.Linear(8192, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, 38)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

model = Downstream(38)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [17]:
from collections import OrderedDict

statedictory = torch.load("/kaggle/input/pretextweights/pretextweights.pth")

newstatedirectory = OrderedDict()
for i, j in statedictory.items():
    if i.startswith("encoder.encoder."):  
        newkey = i.replace("encoder.encoder.", "encoder.")  
        newstatedirectory[newkey] = j

pretextmodel = Pretext()
pretextmodel.load_state_dict(newstatedirectory)

downstreammodel = Downstream(num_classes=38)

downstreammodel.features[0].weight.data = pretextmodel.encoder[0].weight.data.clone()
downstreammodel.features[0].bias.data   = pretextmodel.encoder[0].bias.data.clone()

downstreammodel.features[3].weight.data = pretextmodel.encoder[3].weight.data.clone()
downstreammodel.features[3].bias.data   = pretextmodel.encoder[3].bias.data.clone()

downstreammodel.features[6].weight.data = pretextmodel.encoder[6].weight.data.clone()
downstreammodel.features[6].bias.data   = pretextmodel.encoder[6].bias.data.clone()

In [24]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batchsize = 32
e = 10
learningrate = 1e-3

transformtensor = transforms.Compose([
    transforms.ToTensor()
])

traindirectory = "/kaggle/input/jigsaw-supervised/split_dataset/fold1/train"
valdirectory   = "/kaggle/input/jigsaw-supervised/split_dataset/fold1/val"


train_dataset = datasets.ImageFolder(traindirectory, transform=transformtensor)
val_dataset   = datasets.ImageFolder(valdirectory, transform=transformtensor)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

model = downstreammodel.to(device)

c = nn.CrossEntropyLoss()
o = torch.optim.Adam(model.parameters(), lr=learningrate)

for epoch in range(e):
    print(f"\nEpoch {epoch + 1}/{e}")
    model.train()
    trainpredictions,trainlabels = [],[]
    for x, y in tqdm(train_loader, desc="Training"):
        x, y = x.to(device), y.to(device)
        o.zero_grad()
        out = model(x)
        loss = c(out, y)
        loss.backward()
        o.step()

        predictedclasses=out.argmax(1)          
        predictedclasses=predictedclasses.cpu()  
        predictedclasses=predictedclasses.numpy() 
        trainpredictions.extend(predictedclasses)   
        
        truelabels=y.cpu().numpy()
        trainlabels.extend(truelabels)   

    trainaccuracy= accuracy_score(trainlabels, trainpredictions)
    print(f"Train Accuracy:{trainaccuracy:.4f}")

    model.eval()
    valpredictions,vallabels = [],[]
    with torch.no_grad():
        for x, y in tqdm(val_loader, desc="Validating"):
            x, y = x.to(device), y.to(device)
            out = model(x)
            guessclasses = out.argmax(1)
            guessclasses = guessclasses.cpu()
            guessclasses = guessclasses.numpy()
            valpredictions.extend(guessclasses)
            
            trueclasses = y.cpu().numpy()
            vallabels.extend(trueclasses)


    valaccuracy = accuracy_score(vallabels, valpredictions)
    print(f"Validation Accuracy: {valaccuracy:.4f}")
torch.save(model.state_dict(), "downstream_fold1.pth")


Epoch 1/10


Training: 100%|██████████| 905/905 [03:32<00:00,  4.25it/s]


Train Accuracy:0.3532


Validating: 100%|██████████| 453/453 [01:14<00:00,  6.07it/s]


Validation Accuracy: 0.5827

Epoch 2/10


Training: 100%|██████████| 905/905 [02:21<00:00,  6.39it/s]


Train Accuracy:0.6026


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.43it/s]


Validation Accuracy: 0.7050

Epoch 3/10


Training: 100%|██████████| 905/905 [02:20<00:00,  6.44it/s]


Train Accuracy:0.6993


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.38it/s]


Validation Accuracy: 0.7603

Epoch 4/10


Training: 100%|██████████| 905/905 [02:20<00:00,  6.44it/s]


Train Accuracy:0.7619


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.41it/s]


Validation Accuracy: 0.8040

Epoch 5/10


Training: 100%|██████████| 905/905 [02:20<00:00,  6.43it/s]


Train Accuracy:0.7978


Validating: 100%|██████████| 453/453 [00:41<00:00, 10.98it/s]


Validation Accuracy: 0.8222

Epoch 6/10


Training: 100%|██████████| 905/905 [02:22<00:00,  6.35it/s]


Train Accuracy:0.8275


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.38it/s]


Validation Accuracy: 0.7807

Epoch 7/10


Training: 100%|██████████| 905/905 [03:14<00:00,  4.66it/s]


Train Accuracy:0.8503


Validating: 100%|██████████| 453/453 [01:10<00:00,  6.42it/s]


Validation Accuracy: 0.8001

Epoch 8/10


Training: 100%|██████████| 905/905 [03:16<00:00,  4.60it/s]


Train Accuracy:0.8668


Validating: 100%|██████████| 453/453 [00:51<00:00,  8.82it/s]


Validation Accuracy: 0.8551

Epoch 9/10


Training: 100%|██████████| 905/905 [02:24<00:00,  6.25it/s]


Train Accuracy:0.8837


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.60it/s]


Validation Accuracy: 0.8553

Epoch 10/10


Training: 100%|██████████| 905/905 [02:21<00:00,  6.38it/s]


Train Accuracy:0.8967


Validating: 100%|██████████| 453/453 [00:38<00:00, 11.72it/s]


Validation Accuracy: 0.8566


In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batchsize = 32
e = 10
learningrate = 1e-3

transformtensor = transforms.Compose([
    transforms.ToTensor()
])

traindirectory = "/kaggle/input/jigsaw-supervised/split_dataset/fold2/train"
valdirectory   = "/kaggle/input/jigsaw-supervised/split_dataset/fold2/val"

train_dataset = datasets.ImageFolder(traindirectory, transform=transformtensor)
val_dataset   = datasets.ImageFolder(valdirectory, transform=transformtensor)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

model = downstreammodel.to(device)

model.load_state_dict(torch.load("/kaggle/working/downstream_fold1.pth", map_location=device))

c = nn.CrossEntropyLoss()
o = torch.optim.Adam(model.parameters(), lr=learningrate)

for epoch in range(e):
    print(f"\nEpoch {epoch + 1}/{e}")
    model.train()
    trainpredictions, trainlabels = [], []
    for x, y in tqdm(train_loader, desc="Training"):
        x, y = x.to(device), y.to(device)
        o.zero_grad()
        out = model(x)
        loss = c(out, y)
        loss.backward()
        o.step()

        predictedclasses = out.argmax(1).cpu().numpy()
        trainpredictions.extend(predictedclasses)
        trainlabels.extend(y.cpu().numpy())

    trainaccuracy = accuracy_score(trainlabels, trainpredictions)
    print(f"Train Accuracy: {trainaccuracy:.4f}")

    model.eval()
    valpredictions, vallabels = [], []
    with torch.no_grad():
        for x, y in tqdm(val_loader, desc="Validating"):
            x, y = x.to(device), y.to(device)
            out = model(x)
            guessclasses = out.argmax(1)
            guessclasses = guessclasses.cpu()
            guessclasses = guessclasses.numpy()
            valpredictions.extend(guessclasses)
            
            trueclasses = y.cpu().numpy()
            vallabels.extend(trueclasses)

    valaccuracy = accuracy_score(vallabels, valpredictions)
    print(f"Validation Accuracy: {valaccuracy:.4f}")

torch.save(model.state_dict(), "downstream_fold2.pth")



Epoch 1/10


Training: 100%|██████████| 906/906 [04:34<00:00,  3.30it/s]


Train Accuracy: 0.8693


Validating: 100%|██████████| 453/453 [01:23<00:00,  5.42it/s]


Validation Accuracy: 0.9182

Epoch 2/10


Training: 100%|██████████| 906/906 [02:27<00:00,  6.16it/s]


Train Accuracy: 0.8894


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.41it/s]


Validation Accuracy: 0.9388

Epoch 3/10


Training: 100%|██████████| 906/906 [02:22<00:00,  6.36it/s]


Train Accuracy: 0.9055


Validating: 100%|██████████| 453/453 [00:40<00:00, 11.30it/s]


Validation Accuracy: 0.9342

Epoch 4/10


Training: 100%|██████████| 906/906 [02:22<00:00,  6.34it/s]


Train Accuracy: 0.9153


Validating: 100%|██████████| 453/453 [00:40<00:00, 11.12it/s]


Validation Accuracy: 0.9230

Epoch 5/10


Training: 100%|██████████| 906/906 [02:26<00:00,  6.17it/s]


Train Accuracy: 0.9187


Validating: 100%|██████████| 453/453 [00:41<00:00, 10.92it/s]


Validation Accuracy: 0.9285

Epoch 6/10


Training: 100%|██████████| 906/906 [02:26<00:00,  6.20it/s]


Train Accuracy: 0.9300


Validating: 100%|██████████| 453/453 [00:38<00:00, 11.67it/s]


Validation Accuracy: 0.9268

Epoch 7/10


Training: 100%|██████████| 906/906 [02:21<00:00,  6.41it/s]


Train Accuracy: 0.9338


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.45it/s]


Validation Accuracy: 0.9132

Epoch 8/10


Training: 100%|██████████| 906/906 [02:19<00:00,  6.47it/s]


Train Accuracy: 0.9347


Validating: 100%|██████████| 453/453 [00:38<00:00, 11.87it/s]


Validation Accuracy: 0.9174

Epoch 9/10


Training: 100%|██████████| 906/906 [02:20<00:00,  6.45it/s]


Train Accuracy: 0.9424


Validating: 100%|██████████| 453/453 [00:38<00:00, 11.84it/s]


Validation Accuracy: 0.9054

Epoch 10/10


Training: 100%|██████████| 906/906 [02:21<00:00,  6.38it/s]


Train Accuracy: 0.9474


Validating: 100%|██████████| 453/453 [00:41<00:00, 11.04it/s]

Validation Accuracy: 0.9308





In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batchsize = 32
e = 10
learningrate = 1e-3

transformtensor = transforms.Compose([
    transforms.ToTensor()
])

traindirectory = "/kaggle/input/jigsaw-supervised/split_dataset/fold3/train"
valdirectory   = "/kaggle/input/jigsaw-supervised/split_dataset/fold3/val"

train_dataset = datasets.ImageFolder(traindirectory, transform=transformtensor)
val_dataset   = datasets.ImageFolder(valdirectory, transform=transformtensor)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

model = downstreammodel.to(device)

model.load_state_dict(torch.load("/kaggle/working/downstream_fold2.pth", map_location=device))

c = nn.CrossEntropyLoss()
o = torch.optim.Adam(model.parameters(), lr=learningrate)

for epoch in range(e):
    print(f"\nEpoch {epoch + 1}/{e}")
    model.train()
    trainpredictions, trainlabels = [], []
    for x, y in tqdm(train_loader, desc="Training"):
        x, y = x.to(device), y.to(device)
        o.zero_grad()
        out = model(x)
        loss = c(out, y)
        loss.backward()
        o.step()

        predictedclasses = out.argmax(1).cpu().numpy()
        trainpredictions.extend(predictedclasses)
        trainlabels.extend(y.cpu().numpy())

    trainaccuracy = accuracy_score(trainlabels, trainpredictions)
    print(f"Train Accuracy: {trainaccuracy:.4f}")

    model.eval()
    valpredictions, vallabels = [], []
    with torch.no_grad():
        for x, y in tqdm(val_loader, desc="Validating"):
            x, y = x.to(device), y.to(device)
            out = model(x)
            guessclasses = out.argmax(1)
            guessclasses = guessclasses.cpu()
            guessclasses = guessclasses.numpy()
            valpredictions.extend(guessclasses)
            
            trueclasses = y.cpu().numpy()
            vallabels.extend(trueclasses)

    valaccuracy = accuracy_score(vallabels, valpredictions)
    print(f"Validation Accuracy: {valaccuracy:.4f}")

torch.save(model.state_dict(), "downstream_fold3.pth")



Epoch 1/10


Training: 100%|██████████| 906/906 [05:08<00:00,  2.94it/s]


Train Accuracy: 0.9160


Validating: 100%|██████████| 453/453 [02:01<00:00,  3.73it/s]


Validation Accuracy: 0.9643

Epoch 2/10


Training: 100%|██████████| 906/906 [02:25<00:00,  6.23it/s]


Train Accuracy: 0.9304


Validating: 100%|██████████| 453/453 [00:38<00:00, 11.64it/s]


Validation Accuracy: 0.9703

Epoch 3/10


Training: 100%|██████████| 906/906 [02:20<00:00,  6.43it/s]


Train Accuracy: 0.9422


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.47it/s]


Validation Accuracy: 0.9736

Epoch 4/10


Training: 100%|██████████| 906/906 [02:22<00:00,  6.38it/s]


Train Accuracy: 0.9419


Validating: 100%|██████████| 453/453 [00:38<00:00, 11.71it/s]


Validation Accuracy: 0.9751

Epoch 5/10


Training: 100%|██████████| 906/906 [02:21<00:00,  6.41it/s]


Train Accuracy: 0.9416


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.46it/s]


Validation Accuracy: 0.9489

Epoch 6/10


Training: 100%|██████████| 906/906 [02:21<00:00,  6.41it/s]


Train Accuracy: 0.9498


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.42it/s]


Validation Accuracy: 0.9547

Epoch 7/10


Training: 100%|██████████| 906/906 [02:22<00:00,  6.35it/s]


Train Accuracy: 0.9481


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.38it/s]


Validation Accuracy: 0.9552

Epoch 8/10


Training: 100%|██████████| 906/906 [02:20<00:00,  6.44it/s]


Train Accuracy: 0.9543


Validating: 100%|██████████| 453/453 [00:39<00:00, 11.45it/s]


Validation Accuracy: 0.9600

Epoch 9/10


Training: 100%|██████████| 906/906 [02:20<00:00,  6.45it/s]


Train Accuracy: 0.9597


Validating: 100%|██████████| 453/453 [00:38<00:00, 11.65it/s]


Validation Accuracy: 0.9398

Epoch 10/10


Training: 100%|██████████| 906/906 [02:19<00:00,  6.47it/s]


Train Accuracy: 0.9567


Validating: 100%|██████████| 453/453 [00:40<00:00, 11.24it/s]


Validation Accuracy: 0.9213


In [31]:
testdirectory = "/kaggle/input/jigsaw-supervised/split_dataset/test"

transformtensor = transforms.Compose([
    transforms.ToTensor()
])

testdataset = datasets.ImageFolder(testdirectory, transform=transformtensor)
testloader = DataLoader(testdataset, batch_size=32, shuffle=False, num_workers=0)

model = downstreammodel.to(device)
model.load_state_dict(torch.load("/kaggle/working/downstream_fold3.pth", map_location=device))
model.eval()

testpredictions, testlabels = [],[]

with torch.no_grad():
    for x, y in tqdm(testloader, desc="Testing"):
        x, y = x.to(device), y.to(device)
        outputs = model(x)
        predictions = outputs.argmax(1)
        predictions=predictions.cpu()
        predictions=predictions.numpy()
        testpredictions.extend(predictions)
        testlabels.extend(y.cpu().numpy())

testaccuracy = accuracy_score(testlabels, testpredictions)
print(f"\nTest Accuracy: {testaccuracy:.4f}")

Testing: 100%|██████████| 340/340 [01:07<00:00,  5.04it/s]


Test Accuracy: 0.8592



