In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score

# ========== CONFIG ========== #
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Dataset roots. Set PROTONET_ROOT_TRAIN / PROTONET_ROOT_VAL in the environment
# to point at a local copy of the data without editing this file; the defaults
# are the author's original local paths.
ROOT_TRAIN = os.environ.get(
    "PROTONET_ROOT_TRAIN",
    "C:\\Users\\mrinmoy\\Documents\\GitHub\\hackthoncomsys_face_classification_challenge\\Comsys_Hackathon5\\Task_B\\train",
)
ROOT_VAL = os.environ.get(
    "PROTONET_ROOT_VAL",
    "C:\\Users\\mrinmoy\\Documents\\GitHub\\hackthoncomsys_face_classification_challenge\\Comsys_Hackathon5\\Task_B\\val",
)
N_WAY = 5        # classes per episode
N_SHOT = 5       # distorted support images per class (author notes 7 are available)
N_QUERY = 1      # clean query images per class (author notes each class has 1 clean image)
EPISODES = 2000  # number of training episodes
LR = 1e-3        # Adam learning rate
IMG_SIZE = 224   # images are resized to IMG_SIZE x IMG_SIZE
EMBED_DIM = 256  # output width of ProtoNet's linear projection

# ========== TRANSFORM ========== #
# Preprocessing shared by train and val:
#   1. resize every image to a square IMG_SIZE x IMG_SIZE input,
#   2. convert the PIL image to a tensor,
#   3. normalise each channel with the standard ImageNet mean/std used by
#      torchvision's pretrained backbones.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ========== DATASET ========== #
class FewShotDataset:
    """Episodic sampler over a directory of per-class images.

    Directory layout read by this class::

        root_dir/<class>/*<ext>               clean images   -> episode queries
        root_dir/<class>/distortion/*<ext>    distorted images -> episode support

    Only classes with at least ``n_query`` clean and ``n_shot`` distorted
    images are kept. Class and file lists are sorted, so for a fixed
    ``random`` seed the sampled episodes do not depend on filesystem
    listing order.
    """

    def __init__(self, root_dir, transform, n_shot=5, n_query=1, extensions=(".jpg",)):
        """
        Args:
            root_dir: dataset root containing one sub-directory per class.
            transform: callable applied to each loaded RGB PIL image; its
                results are combined with ``torch.stack``, so it should return
                equally-shaped tensors.
            n_shot: distorted support images sampled per class; must be >= 1.
            n_query: clean query images sampled per class; must be >= 1.
            extensions: filename suffix or tuple of suffixes (case-sensitive)
                treated as images.

        Raises:
            ValueError: if ``n_shot``/``n_query`` < 1, or no class has enough images.
        """
        if n_shot < 1 or n_query < 1:
            # With zero shots/queries an episode would have no images to stack.
            raise ValueError(
                f"n_shot and n_query must be >= 1, got n_shot={n_shot}, n_query={n_query}"
            )
        if isinstance(extensions, str):
            extensions = (extensions,)

        self.root_dir = root_dir
        self.transform = transform
        self.n_shot = n_shot
        self.n_query = n_query
        self.extensions = tuple(extensions)

        self.classes = []
        self.class_to_images = {}

        for cls in sorted(os.listdir(root_dir)):
            class_path = os.path.join(root_dir, cls)
            if not os.path.isdir(class_path):
                continue

            clean_images = self._list_images(class_path)

            distortion_path = os.path.join(class_path, 'distortion')
            if os.path.isdir(distortion_path):
                distorted_images = self._list_images(distortion_path)
            else:
                distorted_images = []

            # Only keep classes that can always fill an episode.
            if len(clean_images) >= self.n_query and len(distorted_images) >= self.n_shot:
                self.classes.append(cls)
                self.class_to_images[cls] = {'clean': clean_images, 'distorted': distorted_images}

        if not self.classes:
            raise ValueError("No class has enough images. Please check dataset or adjust n_shot/n_query.")

    def _list_images(self, directory):
        """Return sorted full paths of the image files directly inside ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(self.extensions)
            and not os.path.isdir(os.path.join(directory, name))
        )

    def _load(self, path):
        """Open ``path`` as RGB, apply the transform, and close the file handle."""
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def sample_episode(self, n_way=5):
        """
        Sample a single N-way K-shot episode.
        Support: distorted images
        Query: clean images

        Labels are episode-local indices 0..n_way-1, in the order the classes
        were drawn.

        Returns:
            ``(support_x, support_y, query_x, query_y)``: the stacked
            transformed images plus label tensors; support holds
            ``n_way * n_shot`` items and query holds ``n_way * n_query``.

        Raises:
            ValueError: if ``n_way`` < 1 or fewer than ``n_way`` usable classes exist.
        """
        if n_way < 1:
            raise ValueError(f"n_way must be >= 1, got {n_way}")
        if len(self.classes) < n_way:
            raise ValueError(f"Not enough classes with sufficient images. Available: {len(self.classes)}, Requested: {n_way}")

        selected_classes = random.sample(self.classes, n_way)

        support_images, support_labels = [], []
        query_images, query_labels = [], []

        # Every class in self.classes passed the size check in __init__, so a
        # single draw always produces a complete episode.
        for label, cls in enumerate(selected_classes):
            images = self.class_to_images[cls]
            for img_path in random.sample(images['distorted'], self.n_shot):
                support_images.append(self._load(img_path))
                support_labels.append(label)
            for img_path in random.sample(images['clean'], self.n_query):
                query_images.append(self._load(img_path))
                query_labels.append(label)

        return (torch.stack(support_images), torch.tensor(support_labels),
                torch.stack(query_images), torch.tensor(query_labels))

# ========== MODEL ========== #
class ProtoNet(nn.Module):
    """Embedding network: ImageNet-pretrained ResNet-18 trunk plus a linear projection.

    The ResNet-18 classification head is discarded, and the pooled trunk feature
    (``base.fc.in_features`` wide) is projected to ``embed_dim`` dimensions.
    """

    def __init__(self, embed_dim=EMBED_DIM):
        super().__init__()
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is
            # the checkpoint that `pretrained=True` loads.
            base = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            base = models.resnet18(pretrained=True)
        # All children except the final fc layer (the conv trunk and avg pool).
        self.features = nn.Sequential(*list(base.children())[:-1])
        self.fc = nn.Linear(base.fc.in_features, embed_dim)

    def forward(self, x):
        """Map an image batch to embeddings of shape (batch, embed_dim).

        Assumes ``x`` is a (batch, 3, H, W) image tensor — TODO confirm for new callers.
        """
        # Flattening from dim 1 keeps the batch axis even when batch == 1,
        # whereas `.squeeze()` would also drop it and yield a 1-D embedding.
        return self.fc(torch.flatten(self.features(x), 1))

# ========== PROTOTYPICAL PREDICTION ========== #
def euclidean_dist(a, b):
    """Pairwise *squared* Euclidean distances between the rows of two matrices.

    Args:
        a: tensor of shape (n, d).
        b: tensor of shape (m, d).

    Returns:
        An (n, m) tensor whose [i, j] entry is ``((a[i] - b[j]) ** 2).sum()``.
        No square root is taken.
    """
    # Insert singleton axes so (n, 1, d) and (1, m, d) broadcast to (n, m, d).
    pairwise_diff = a[:, None, :] - b[None, :, :]
    return pairwise_diff.pow(2).sum(dim=2)

def predict_protonet(support_embeddings, support_labels, query_embeddings, n_way):
    """Score every query embedding against the episode's class prototypes.

    The prototype of class ``k`` is the mean of the support embeddings whose
    label equals ``k`` (labels are expected to be 0..n_way-1). Scores are the
    negated squared Euclidean distances from ``euclidean_dist``, so a higher
    score means closer, and ``argmax(1)`` gives the predicted label.

    Returns:
        A (num_queries, n_way) score tensor, assuming 2-D (N, D) embeddings.
    """
    prototypes = torch.stack([
        support_embeddings[support_labels == label].mean(0)
        for label in range(n_way)
    ])
    return -euclidean_dist(query_embeddings, prototypes)

# ========== VALIDATION ========== #
def validate(model, val_dataset, device, n_way=5, n_shot=5, n_query=5, episodes=50):
    """Estimate few-shot accuracy of ``model`` over freshly sampled episodes.

    Args:
        model: embedding network mapping an image batch to embeddings.
        val_dataset: object exposing ``sample_episode(n_way)`` that returns
            ``(support_x, support_y, query_x, query_y)``.
        device: device the tensors and model run on.
        n_way: classes per episode.
        n_shot, n_query: accepted for call-site compatibility but NOT used;
            the shot/query counts come from ``val_dataset``'s own settings.
        episodes: number of episodes to average over; must be >= 1.

    Returns:
        The mean per-episode query accuracy, as a float.

    Raises:
        ValueError: if ``episodes`` < 1.
    """
    if episodes < 1:
        raise ValueError(f"episodes must be >= 1, got {episodes}")

    was_training = model.training
    model.eval()
    accs = []

    try:
        with torch.no_grad():
            for _ in range(episodes):
                support_x, support_y, query_x, query_y = val_dataset.sample_episode(n_way)
                support_x, support_y = support_x.to(device), support_y.to(device)
                query_x, query_y = query_x.to(device), query_y.to(device)

                support_emb = model(support_x)
                query_emb = model(query_x)

                logits = predict_protonet(support_emb, support_y, query_emb, n_way)
                preds = logits.argmax(1)
                accs.append((preds == query_y).float().mean().item())
    finally:
        # Restore the caller's train/eval mode so validation leaves no lasting side effect.
        model.train(was_training)

    return sum(accs) / len(accs)

# ========== TRAINING ========== #
def train_protonet(episodes=EPISODES, eval_every=10, val_episodes=20,
                   save_path="best_protonet.pth"):
    """Train a ProtoNet episodically and checkpoint the best validation model.

    Each episode samples N_WAY classes from ROOT_TRAIN, embeds the support and
    query images, and minimises cross-entropy over the negated prototype
    distances. Every ``eval_every`` episodes the model is scored on
    ``val_episodes`` episodes from ROOT_VAL. Whenever that score improves, the
    ``state_dict`` is written to ``save_path``. The logged Loss / Train Acc
    values are averages over the episodes since the previous log line.

    Args:
        episodes: number of training episodes (default: EPISODES).
        eval_every: logging/validation interval in episodes; must be >= 1.
        val_episodes: episodes sampled per validation run.
        save_path: destination file for the best ``state_dict``.

    Returns:
        The best validation accuracy seen, or ``float("-inf")`` if validation
        never ran.

    Raises:
        ValueError: if ``eval_every`` < 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    # Pass the configured shot/query counts explicitly; relying on the class
    # defaults would make changes to N_SHOT / N_QUERY silently ineffective.
    train_dataset = FewShotDataset(ROOT_TRAIN, transform, n_shot=N_SHOT, n_query=N_QUERY)
    val_dataset = FewShotDataset(ROOT_VAL, transform, n_shot=N_SHOT, n_query=N_QUERY)

    model = ProtoNet().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # -inf guarantees the first evaluation produces a checkpoint.
    best_val_acc = float("-inf")

    # Running sums over the current logging window: one episode has only
    # N_WAY * N_QUERY queries, so a single-episode value is very noisy.
    window_loss, window_acc, window_len = 0.0, 0.0, 0

    for episode in tqdm(range(episodes)):
        support_x, support_y, query_x, query_y = train_dataset.sample_episode(N_WAY)
        support_x, support_y = support_x.to(DEVICE), support_y.to(DEVICE)
        query_x, query_y = query_x.to(DEVICE), query_y.to(DEVICE)

        # validate() may leave the model in eval mode; re-enter train mode each episode.
        model.train()
        # NOTE(review): support and query go through separate forward passes, so
        # in train mode any BatchNorm layers normalise each with its own batch stats.
        support_emb = model(support_x)
        query_emb = model(query_x)

        logits = predict_protonet(support_emb, support_y, query_emb, N_WAY)
        loss = F.cross_entropy(logits, query_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        window_loss += loss.item()
        window_acc += (logits.argmax(1) == query_y).float().mean().item()
        window_len += 1

        if (episode + 1) % eval_every == 0:
            val_acc = validate(model, val_dataset, DEVICE, N_WAY, N_SHOT, N_QUERY, episodes=val_episodes)
            print(f"Episode {episode+1} | Loss: {window_loss / window_len:.4f} | "
                  f"Train Acc: {window_acc / window_len:.4f} | Val Acc: {val_acc:.4f}")
            window_loss, window_acc, window_len = 0.0, 0.0, 0

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), save_path)
                print(f"Best model saved at Episode {episode+1} with Val Acc: {val_acc:.4f}")

    return best_val_acc


if __name__ == "__main__":
    train_protonet()


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 190MB/s]
  1%|          | 11/2000 [00:07<34:12,  1.03s/it]

Episode 10 | Loss: 0.0014 | Train Acc: 1.0000 | Val Acc: 0.4667
Best model saved at Episode 10 with Val Acc: 0.4667


  1%|          | 20/2000 [00:13<55:28,  1.68s/it]

Episode 20 | Loss: 1.3259 | Train Acc: 0.6667 | Val Acc: 0.7000
Best model saved at Episode 20 with Val Acc: 0.7000


  2%|▏         | 30/2000 [00:23<1:09:09,  2.11s/it]

Episode 30 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8333
Best model saved at Episode 30 with Val Acc: 0.8333


  2%|▏         | 41/2000 [00:31<46:17,  1.42s/it]  

Episode 40 | Loss: 51.9947 | Train Acc: 0.0000 | Val Acc: 0.5500


  3%|▎         | 51/2000 [00:40<49:50,  1.53s/it]  

Episode 50 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.6000


  3%|▎         | 61/2000 [00:49<44:58,  1.39s/it]  

Episode 60 | Loss: 0.0001 | Train Acc: 1.0000 | Val Acc: 0.7000


  4%|▎         | 71/2000 [00:55<34:34,  1.08s/it]

Episode 70 | Loss: 0.8471 | Train Acc: 0.6667 | Val Acc: 0.6667


  4%|▍         | 80/2000 [01:04<1:05:45,  2.05s/it]

Episode 80 | Loss: 0.2358 | Train Acc: 0.6667 | Val Acc: 0.8667
Best model saved at Episode 80 with Val Acc: 0.8667


  5%|▍         | 91/2000 [01:12<50:59,  1.60s/it]  

Episode 90 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8000


  5%|▌         | 100/2000 [01:20<58:05,  1.83s/it]

Episode 100 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8333


  6%|▌         | 111/2000 [01:28<44:38,  1.42s/it]  

Episode 110 | Loss: 0.0016 | Train Acc: 1.0000 | Val Acc: 0.8167


  6%|▌         | 121/2000 [01:38<54:48,  1.75s/it]  

Episode 120 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.7333


  6%|▋         | 130/2000 [01:47<1:02:17,  2.00s/it]

Episode 130 | Loss: 0.6738 | Train Acc: 0.3333 | Val Acc: 0.4667


  7%|▋         | 141/2000 [01:55<40:49,  1.32s/it]

Episode 140 | Loss: 0.0005 | Train Acc: 1.0000 | Val Acc: 0.7333


  8%|▊         | 151/2000 [02:02<34:44,  1.13s/it]

Episode 150 | Loss: 1.4603 | Train Acc: 0.0000 | Val Acc: 0.8167


  8%|▊         | 161/2000 [02:08<30:58,  1.01s/it]

Episode 160 | Loss: 0.0513 | Train Acc: 1.0000 | Val Acc: 0.7833


  9%|▊         | 171/2000 [02:15<32:49,  1.08s/it]

Episode 170 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.6500


  9%|▉         | 180/2000 [02:22<46:40,  1.54s/it]

Episode 180 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.6500


 10%|▉         | 190/2000 [02:30<1:00:43,  2.01s/it]

Episode 190 | Loss: 0.0481 | Train Acc: 1.0000 | Val Acc: 0.4167


 10%|█         | 201/2000 [02:38<37:29,  1.25s/it]

Episode 200 | Loss: 2.4278 | Train Acc: 0.6667 | Val Acc: 0.6167


 10%|█         | 210/2000 [02:47<1:09:01,  2.31s/it]

Episode 210 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.6167


 11%|█         | 220/2000 [02:54<47:01,  1.59s/it]

Episode 220 | Loss: 0.5377 | Train Acc: 0.6667 | Val Acc: 0.8833
Best model saved at Episode 220 with Val Acc: 0.8833


 12%|█▏        | 230/2000 [03:04<1:03:54,  2.17s/it]

Episode 230 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.7167


 12%|█▏        | 241/2000 [03:12<34:17,  1.17s/it]

Episode 240 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8500


 13%|█▎        | 251/2000 [03:17<29:44,  1.02s/it]

Episode 250 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8167


 13%|█▎        | 261/2000 [03:25<31:18,  1.08s/it]

Episode 260 | Loss: 0.3270 | Train Acc: 0.6667 | Val Acc: 0.7333


 14%|█▎        | 271/2000 [03:35<47:33,  1.65s/it]  

Episode 270 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.6167


 14%|█▍        | 280/2000 [03:43<41:26,  1.45s/it]

Episode 280 | Loss: 3.5635 | Train Acc: 0.6667 | Val Acc: 0.6000


 14%|█▍        | 290/2000 [03:52<49:34,  1.74s/it]

Episode 290 | Loss: 0.0241 | Train Acc: 1.0000 | Val Acc: 0.9000
Best model saved at Episode 290 with Val Acc: 0.9000


 15%|█▌        | 300/2000 [04:00<52:15,  1.84s/it]

Episode 300 | Loss: 0.4578 | Train Acc: 1.0000 | Val Acc: 0.9167
Best model saved at Episode 300 with Val Acc: 0.9167


 16%|█▌        | 311/2000 [04:07<34:26,  1.22s/it]

Episode 310 | Loss: 0.0321 | Train Acc: 1.0000 | Val Acc: 0.8167


 16%|█▌        | 320/2000 [04:12<23:26,  1.19it/s]

Episode 320 | Loss: 0.2488 | Train Acc: 1.0000 | Val Acc: 0.8667


 17%|█▋        | 331/2000 [04:19<25:43,  1.08it/s]

Episode 330 | Loss: 0.0011 | Train Acc: 1.0000 | Val Acc: 0.8000


 17%|█▋        | 341/2000 [04:25<26:03,  1.06it/s]

Episode 340 | Loss: 0.4818 | Train Acc: 0.6667 | Val Acc: 0.8500


 18%|█▊        | 350/2000 [04:31<31:50,  1.16s/it]

Episode 350 | Loss: 0.0334 | Train Acc: 1.0000 | Val Acc: 0.8833


 18%|█▊        | 360/2000 [04:38<35:52,  1.31s/it]

Episode 360 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8333


 19%|█▊        | 371/2000 [04:46<32:37,  1.20s/it]

Episode 370 | Loss: 8.0375 | Train Acc: 0.6667 | Val Acc: 0.8667


 19%|█▉        | 380/2000 [04:50<28:43,  1.06s/it]

Episode 380 | Loss: 0.0067 | Train Acc: 1.0000 | Val Acc: 0.8167


 20%|█▉        | 391/2000 [04:59<30:54,  1.15s/it]

Episode 390 | Loss: 0.0706 | Train Acc: 1.0000 | Val Acc: 0.6833


 20%|██        | 401/2000 [05:05<26:16,  1.01it/s]

Episode 400 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8000


 21%|██        | 411/2000 [05:12<28:11,  1.06s/it]

Episode 410 | Loss: 0.0007 | Train Acc: 1.0000 | Val Acc: 0.7667


 21%|██        | 422/2000 [05:18<17:36,  1.49it/s]

Episode 420 | Loss: 0.0079 | Train Acc: 1.0000 | Val Acc: 0.8000


 22%|██▏       | 431/2000 [05:24<28:25,  1.09s/it]

Episode 430 | Loss: 0.0287 | Train Acc: 1.0000 | Val Acc: 0.7833


 22%|██▏       | 441/2000 [05:33<37:08,  1.43s/it]

Episode 440 | Loss: 0.5218 | Train Acc: 1.0000 | Val Acc: 0.7667


 22%|██▎       | 450/2000 [05:39<30:25,  1.18s/it]

Episode 450 | Loss: 0.0135 | Train Acc: 1.0000 | Val Acc: 0.8333


 23%|██▎       | 461/2000 [05:46<23:15,  1.10it/s]

Episode 460 | Loss: 0.0564 | Train Acc: 1.0000 | Val Acc: 0.7500


 24%|██▎       | 472/2000 [05:52<18:40,  1.36it/s]

Episode 470 | Loss: 0.2269 | Train Acc: 1.0000 | Val Acc: 0.7667


 24%|██▍       | 482/2000 [05:59<22:46,  1.11it/s]

Episode 480 | Loss: 0.0540 | Train Acc: 1.0000 | Val Acc: 0.9167


 25%|██▍       | 491/2000 [06:04<25:27,  1.01s/it]

Episode 490 | Loss: 0.1161 | Train Acc: 1.0000 | Val Acc: 0.9167


 25%|██▌       | 501/2000 [06:11<22:49,  1.09it/s]

Episode 500 | Loss: 0.1757 | Train Acc: 1.0000 | Val Acc: 0.8500


 26%|██▌       | 512/2000 [06:18<23:41,  1.05it/s]

Episode 510 | Loss: 0.0039 | Train Acc: 1.0000 | Val Acc: 0.7833


 26%|██▌       | 521/2000 [06:25<27:44,  1.13s/it]

Episode 520 | Loss: 0.2053 | Train Acc: 1.0000 | Val Acc: 0.8500


 27%|██▋       | 532/2000 [06:32<19:10,  1.28it/s]

Episode 530 | Loss: 0.0130 | Train Acc: 1.0000 | Val Acc: 0.8667


 27%|██▋       | 540/2000 [06:40<45:13,  1.86s/it]

Episode 540 | Loss: 0.0011 | Train Acc: 1.0000 | Val Acc: 0.9167


 28%|██▊       | 550/2000 [06:48<39:21,  1.63s/it]

Episode 550 | Loss: 0.0220 | Train Acc: 1.0000 | Val Acc: 0.9000


 28%|██▊       | 560/2000 [06:54<30:46,  1.28s/it]

Episode 560 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8333


 29%|██▊       | 571/2000 [07:03<29:11,  1.23s/it]

Episode 570 | Loss: 0.0011 | Train Acc: 1.0000 | Val Acc: 0.9000


 29%|██▉       | 582/2000 [07:09<18:21,  1.29it/s]

Episode 580 | Loss: 0.0001 | Train Acc: 1.0000 | Val Acc: 0.7667


 30%|██▉       | 590/2000 [07:16<35:20,  1.50s/it]

Episode 590 | Loss: 0.0327 | Train Acc: 1.0000 | Val Acc: 0.9167


 30%|███       | 602/2000 [07:23<19:48,  1.18it/s]

Episode 600 | Loss: 0.0001 | Train Acc: 1.0000 | Val Acc: 0.8667


 30%|███       | 610/2000 [07:29<24:30,  1.06s/it]

Episode 610 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.7333


 31%|███       | 620/2000 [07:35<27:18,  1.19s/it]

Episode 620 | Loss: 0.6636 | Train Acc: 0.6667 | Val Acc: 0.7833


 32%|███▏      | 630/2000 [07:45<49:13,  2.16s/it]

Episode 630 | Loss: 0.1279 | Train Acc: 1.0000 | Val Acc: 0.7667


 32%|███▏      | 642/2000 [07:51<21:05,  1.07it/s]

Episode 640 | Loss: 0.0034 | Train Acc: 1.0000 | Val Acc: 0.7667


 33%|███▎      | 651/2000 [07:56<17:41,  1.27it/s]

Episode 650 | Loss: 0.0017 | Train Acc: 1.0000 | Val Acc: 0.7833


 33%|███▎      | 662/2000 [08:03<22:33,  1.01s/it]

Episode 660 | Loss: 0.0929 | Train Acc: 1.0000 | Val Acc: 0.8667


 34%|███▎      | 670/2000 [08:11<42:45,  1.93s/it]

Episode 670 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8833


 34%|███▍      | 681/2000 [08:18<19:38,  1.12it/s]

Episode 680 | Loss: 0.0008 | Train Acc: 1.0000 | Val Acc: 0.8833


 34%|███▍      | 690/2000 [08:26<32:50,  1.50s/it]

Episode 690 | Loss: 0.0006 | Train Acc: 1.0000 | Val Acc: 0.8000


 35%|███▌      | 701/2000 [08:34<29:01,  1.34s/it]

Episode 700 | Loss: 0.0024 | Train Acc: 1.0000 | Val Acc: 0.9000


 36%|███▌      | 712/2000 [08:42<17:48,  1.21it/s]

Episode 710 | Loss: 0.0138 | Train Acc: 1.0000 | Val Acc: 0.8333


 36%|███▌      | 720/2000 [08:48<26:13,  1.23s/it]

Episode 720 | Loss: 0.0203 | Train Acc: 1.0000 | Val Acc: 0.8500


 36%|███▋      | 730/2000 [08:54<22:27,  1.06s/it]

Episode 730 | Loss: 0.0695 | Train Acc: 1.0000 | Val Acc: 0.8500


 37%|███▋      | 742/2000 [09:01<16:13,  1.29it/s]

Episode 740 | Loss: 0.0024 | Train Acc: 1.0000 | Val Acc: 0.8000


 38%|███▊      | 751/2000 [09:08<21:00,  1.01s/it]

Episode 750 | Loss: 0.0011 | Train Acc: 1.0000 | Val Acc: 0.8667


 38%|███▊      | 762/2000 [09:14<16:40,  1.24it/s]

Episode 760 | Loss: 0.0380 | Train Acc: 1.0000 | Val Acc: 0.7667


 39%|███▊      | 771/2000 [09:22<23:43,  1.16s/it]

Episode 770 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8167


 39%|███▉      | 780/2000 [09:30<27:01,  1.33s/it]

Episode 780 | Loss: 0.0452 | Train Acc: 1.0000 | Val Acc: 0.9333
Best model saved at Episode 780 with Val Acc: 0.9333


 40%|███▉      | 791/2000 [09:39<22:35,  1.12s/it]

Episode 790 | Loss: 0.0560 | Train Acc: 1.0000 | Val Acc: 0.8667


 40%|████      | 801/2000 [09:46<26:04,  1.31s/it]

Episode 800 | Loss: 0.0330 | Train Acc: 1.0000 | Val Acc: 0.8333


 40%|████      | 810/2000 [09:53<30:14,  1.52s/it]

Episode 810 | Loss: 0.0072 | Train Acc: 1.0000 | Val Acc: 0.7833


 41%|████      | 820/2000 [10:02<29:49,  1.52s/it]

Episode 820 | Loss: 0.0001 | Train Acc: 1.0000 | Val Acc: 0.8000


 42%|████▏     | 830/2000 [10:09<22:59,  1.18s/it]

Episode 830 | Loss: 0.0021 | Train Acc: 1.0000 | Val Acc: 0.8167


 42%|████▏     | 842/2000 [10:17<13:46,  1.40it/s]

Episode 840 | Loss: 0.5649 | Train Acc: 0.6667 | Val Acc: 0.7333


 43%|████▎     | 852/2000 [10:23<14:51,  1.29it/s]

Episode 850 | Loss: 0.0011 | Train Acc: 1.0000 | Val Acc: 0.8500


 43%|████▎     | 861/2000 [10:29<18:20,  1.03it/s]

Episode 860 | Loss: 4.6499 | Train Acc: 0.3333 | Val Acc: 0.8500


 44%|████▎     | 871/2000 [10:35<20:37,  1.10s/it]

Episode 870 | Loss: 0.0146 | Train Acc: 1.0000 | Val Acc: 0.8833


 44%|████▍     | 880/2000 [10:42<28:57,  1.55s/it]

Episode 880 | Loss: 0.0318 | Train Acc: 1.0000 | Val Acc: 0.8333


 44%|████▍     | 890/2000 [10:51<29:50,  1.61s/it]

Episode 890 | Loss: 0.8166 | Train Acc: 0.6667 | Val Acc: 0.8500


 45%|████▌     | 900/2000 [10:59<26:20,  1.44s/it]

Episode 900 | Loss: 0.0163 | Train Acc: 1.0000 | Val Acc: 0.7667


 46%|████▌     | 910/2000 [11:08<26:08,  1.44s/it]

Episode 910 | Loss: 0.1624 | Train Acc: 1.0000 | Val Acc: 0.8167


 46%|████▌     | 920/2000 [11:15<19:24,  1.08s/it]

Episode 920 | Loss: 0.0009 | Train Acc: 1.0000 | Val Acc: 0.8500


 47%|████▋     | 932/2000 [11:22<13:37,  1.31it/s]

Episode 930 | Loss: 0.0036 | Train Acc: 1.0000 | Val Acc: 0.9167


 47%|████▋     | 940/2000 [11:29<23:45,  1.35s/it]

Episode 940 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8333


 48%|████▊     | 950/2000 [11:36<16:48,  1.04it/s]

Episode 950 | Loss: 0.0659 | Train Acc: 1.0000 | Val Acc: 0.8500


 48%|████▊     | 960/2000 [11:42<19:39,  1.13s/it]

Episode 960 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.7167


 49%|████▊     | 972/2000 [11:49<12:32,  1.37it/s]

Episode 970 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.7333


 49%|████▉     | 980/2000 [11:55<17:32,  1.03s/it]

Episode 980 | Loss: 0.3717 | Train Acc: 0.6667 | Val Acc: 0.7000


 50%|████▉     | 990/2000 [12:01<20:50,  1.24s/it]

Episode 990 | Loss: 0.0863 | Train Acc: 1.0000 | Val Acc: 0.7000


 50%|█████     | 1000/2000 [12:05<10:52,  1.53it/s]

Episode 1000 | Loss: 0.0100 | Train Acc: 1.0000 | Val Acc: 0.8667


 50%|█████     | 1010/2000 [12:11<15:01,  1.10it/s]

Episode 1010 | Loss: 4.7124 | Train Acc: 0.3333 | Val Acc: 0.8000


 51%|█████     | 1020/2000 [12:17<16:13,  1.01it/s]

Episode 1020 | Loss: 0.1224 | Train Acc: 1.0000 | Val Acc: 0.8000


 52%|█████▏    | 1030/2000 [12:24<20:39,  1.28s/it]

Episode 1030 | Loss: 0.0356 | Train Acc: 1.0000 | Val Acc: 0.8500


 52%|█████▏    | 1040/2000 [12:32<21:12,  1.33s/it]

Episode 1040 | Loss: 0.0140 | Train Acc: 1.0000 | Val Acc: 0.8500


 53%|█████▎    | 1052/2000 [12:39<13:57,  1.13it/s]

Episode 1050 | Loss: 0.0062 | Train Acc: 1.0000 | Val Acc: 0.9167


 53%|█████▎    | 1062/2000 [12:47<15:46,  1.01s/it]

Episode 1060 | Loss: 0.0558 | Train Acc: 1.0000 | Val Acc: 0.8667


 54%|█████▎    | 1072/2000 [12:57<16:12,  1.05s/it]

Episode 1070 | Loss: 0.0060 | Train Acc: 1.0000 | Val Acc: 0.8667


 54%|█████▍    | 1082/2000 [13:03<11:20,  1.35it/s]

Episode 1080 | Loss: 0.0059 | Train Acc: 1.0000 | Val Acc: 0.8500


 55%|█████▍    | 1090/2000 [13:10<23:15,  1.53s/it]

Episode 1090 | Loss: 0.0057 | Train Acc: 1.0000 | Val Acc: 0.8500


 55%|█████▌    | 1102/2000 [13:17<12:46,  1.17it/s]

Episode 1100 | Loss: 0.0204 | Train Acc: 1.0000 | Val Acc: 0.8667


 56%|█████▌    | 1112/2000 [13:24<12:39,  1.17it/s]

Episode 1110 | Loss: 0.0756 | Train Acc: 1.0000 | Val Acc: 0.7167


 56%|█████▌    | 1120/2000 [13:29<14:32,  1.01it/s]

Episode 1120 | Loss: 1.7776 | Train Acc: 0.3333 | Val Acc: 0.6167


 57%|█████▋    | 1131/2000 [13:36<13:33,  1.07it/s]

Episode 1130 | Loss: 0.0008 | Train Acc: 1.0000 | Val Acc: 0.8167


 57%|█████▋    | 1140/2000 [13:42<19:02,  1.33s/it]

Episode 1140 | Loss: 0.0995 | Train Acc: 1.0000 | Val Acc: 0.9167


 58%|█████▊    | 1152/2000 [13:50<11:56,  1.18it/s]

Episode 1150 | Loss: 0.0290 | Train Acc: 1.0000 | Val Acc: 0.9333


 58%|█████▊    | 1162/2000 [13:54<07:21,  1.90it/s]

Episode 1160 | Loss: 0.7149 | Train Acc: 0.3333 | Val Acc: 0.8833


 59%|█████▊    | 1171/2000 [14:00<12:20,  1.12it/s]

Episode 1170 | Loss: 0.3375 | Train Acc: 0.6667 | Val Acc: 0.9000


 59%|█████▉    | 1180/2000 [14:07<17:34,  1.29s/it]

Episode 1180 | Loss: 0.0331 | Train Acc: 1.0000 | Val Acc: 0.9500
Best model saved at Episode 1180 with Val Acc: 0.9500


 60%|█████▉    | 1190/2000 [14:14<17:41,  1.31s/it]

Episode 1190 | Loss: 0.1881 | Train Acc: 1.0000 | Val Acc: 0.8833


 60%|██████    | 1202/2000 [14:22<13:04,  1.02it/s]

Episode 1200 | Loss: 0.7260 | Train Acc: 0.6667 | Val Acc: 0.9500


 60%|██████    | 1210/2000 [14:28<15:49,  1.20s/it]

Episode 1210 | Loss: 0.2038 | Train Acc: 1.0000 | Val Acc: 0.9667
Best model saved at Episode 1210 with Val Acc: 0.9667


 61%|██████    | 1220/2000 [14:34<10:29,  1.24it/s]

Episode 1220 | Loss: 0.1124 | Train Acc: 1.0000 | Val Acc: 0.9167


 62%|██████▏   | 1230/2000 [14:40<14:04,  1.10s/it]

Episode 1230 | Loss: 0.0173 | Train Acc: 1.0000 | Val Acc: 0.8833


 62%|██████▏   | 1241/2000 [14:48<13:08,  1.04s/it]

Episode 1240 | Loss: 0.7838 | Train Acc: 0.6667 | Val Acc: 0.8667


 63%|██████▎   | 1252/2000 [14:55<08:45,  1.42it/s]

Episode 1250 | Loss: 0.0011 | Train Acc: 1.0000 | Val Acc: 0.8500


 63%|██████▎   | 1262/2000 [15:02<12:00,  1.02it/s]

Episode 1260 | Loss: 0.0012 | Train Acc: 1.0000 | Val Acc: 0.8333


 64%|██████▎   | 1270/2000 [15:11<21:52,  1.80s/it]

Episode 1270 | Loss: 0.0092 | Train Acc: 1.0000 | Val Acc: 0.9167


 64%|██████▍   | 1280/2000 [15:21<22:21,  1.86s/it]

Episode 1280 | Loss: 0.2092 | Train Acc: 1.0000 | Val Acc: 0.9667


 64%|██████▍   | 1290/2000 [15:26<12:51,  1.09s/it]

Episode 1290 | Loss: 0.0069 | Train Acc: 1.0000 | Val Acc: 0.9167


 65%|██████▌   | 1300/2000 [15:32<14:59,  1.29s/it]

Episode 1300 | Loss: 0.0514 | Train Acc: 1.0000 | Val Acc: 0.9667


 66%|██████▌   | 1311/2000 [15:41<14:13,  1.24s/it]

Episode 1310 | Loss: 0.0006 | Train Acc: 1.0000 | Val Acc: 0.9167


 66%|██████▌   | 1322/2000 [15:49<09:58,  1.13it/s]

Episode 1320 | Loss: 0.0001 | Train Acc: 1.0000 | Val Acc: 0.9167


 67%|██████▋   | 1331/2000 [15:53<08:02,  1.39it/s]

Episode 1330 | Loss: 0.0001 | Train Acc: 1.0000 | Val Acc: 0.9333


 67%|██████▋   | 1342/2000 [16:00<08:50,  1.24it/s]

Episode 1340 | Loss: 0.0004 | Train Acc: 1.0000 | Val Acc: 0.8833


 68%|██████▊   | 1352/2000 [16:06<09:13,  1.17it/s]

Episode 1350 | Loss: 0.1947 | Train Acc: 1.0000 | Val Acc: 0.9000


 68%|██████▊   | 1361/2000 [16:13<11:15,  1.06s/it]

Episode 1360 | Loss: 0.0119 | Train Acc: 1.0000 | Val Acc: 0.9000


 69%|██████▊   | 1372/2000 [16:18<07:28,  1.40it/s]

Episode 1370 | Loss: 0.0401 | Train Acc: 1.0000 | Val Acc: 0.9500


 69%|██████▉   | 1380/2000 [16:23<07:59,  1.29it/s]

Episode 1380 | Loss: 0.0170 | Train Acc: 1.0000 | Val Acc: 0.9000


 70%|██████▉   | 1392/2000 [16:30<07:55,  1.28it/s]

Episode 1390 | Loss: 0.4385 | Train Acc: 0.6667 | Val Acc: 0.8833


 70%|███████   | 1400/2000 [16:35<09:17,  1.08it/s]

Episode 1400 | Loss: 0.0001 | Train Acc: 1.0000 | Val Acc: 0.9167


 70%|███████   | 1410/2000 [16:39<05:51,  1.68it/s]

Episode 1410 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.8000


 71%|███████   | 1420/2000 [16:45<08:22,  1.16it/s]

Episode 1420 | Loss: 0.0016 | Train Acc: 1.0000 | Val Acc: 0.9000


 72%|███████▏  | 1430/2000 [16:52<10:57,  1.15s/it]

Episode 1430 | Loss: 0.1247 | Train Acc: 1.0000 | Val Acc: 0.8833


 72%|███████▏  | 1441/2000 [16:59<08:42,  1.07it/s]

Episode 1440 | Loss: 0.0009 | Train Acc: 1.0000 | Val Acc: 0.8667


 73%|███████▎  | 1452/2000 [17:07<09:28,  1.04s/it]

Episode 1450 | Loss: 0.0133 | Train Acc: 1.0000 | Val Acc: 0.9500


 73%|███████▎  | 1460/2000 [17:14<10:23,  1.15s/it]

Episode 1460 | Loss: 0.0003 | Train Acc: 1.0000 | Val Acc: 0.7833


 74%|███████▎  | 1470/2000 [17:21<10:38,  1.20s/it]

Episode 1470 | Loss: 3.3260 | Train Acc: 0.6667 | Val Acc: 0.8167


 74%|███████▍  | 1482/2000 [17:26<05:22,  1.61it/s]

Episode 1480 | Loss: 0.0007 | Train Acc: 1.0000 | Val Acc: 0.9000


 74%|███████▍  | 1490/2000 [17:30<05:32,  1.53it/s]

Episode 1490 | Loss: 0.0495 | Train Acc: 1.0000 | Val Acc: 0.8667


 75%|███████▌  | 1500/2000 [17:36<09:31,  1.14s/it]

Episode 1500 | Loss: 0.0001 | Train Acc: 1.0000 | Val Acc: 0.9500


 76%|███████▌  | 1510/2000 [17:42<08:12,  1.01s/it]

Episode 1510 | Loss: 0.8954 | Train Acc: 0.6667 | Val Acc: 0.9000


 76%|███████▌  | 1521/2000 [17:52<10:28,  1.31s/it]

Episode 1520 | Loss: 0.0186 | Train Acc: 1.0000 | Val Acc: 0.9667


 77%|███████▋  | 1532/2000 [17:58<05:56,  1.31it/s]

Episode 1530 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.9333


 77%|███████▋  | 1540/2000 [18:06<09:38,  1.26s/it]

Episode 1540 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.9500


 78%|███████▊  | 1552/2000 [18:11<04:09,  1.80it/s]

Episode 1550 | Loss: 1.2659 | Train Acc: 0.3333 | Val Acc: 0.9167


 78%|███████▊  | 1560/2000 [18:17<07:24,  1.01s/it]

Episode 1560 | Loss: 0.0603 | Train Acc: 1.0000 | Val Acc: 0.9000


 79%|███████▊  | 1572/2000 [18:26<08:07,  1.14s/it]

Episode 1570 | Loss: 0.0831 | Train Acc: 1.0000 | Val Acc: 0.9000


 79%|███████▉  | 1581/2000 [18:33<07:55,  1.14s/it]

Episode 1580 | Loss: 0.6886 | Train Acc: 0.3333 | Val Acc: 0.9667


 80%|███████▉  | 1590/2000 [18:40<07:12,  1.05s/it]

Episode 1590 | Loss: 0.0170 | Train Acc: 1.0000 | Val Acc: 0.9833
Best model saved at Episode 1590 with Val Acc: 0.9833


 80%|████████  | 1602/2000 [18:47<05:26,  1.22it/s]

Episode 1600 | Loss: 0.0906 | Train Acc: 1.0000 | Val Acc: 0.9500


 81%|████████  | 1611/2000 [18:55<07:35,  1.17s/it]

Episode 1610 | Loss: 0.1285 | Train Acc: 1.0000 | Val Acc: 0.9333


 81%|████████  | 1622/2000 [19:02<05:24,  1.16it/s]

Episode 1620 | Loss: 0.0037 | Train Acc: 1.0000 | Val Acc: 0.9000


 82%|████████▏ | 1632/2000 [19:10<06:21,  1.04s/it]

Episode 1630 | Loss: 0.0055 | Train Acc: 1.0000 | Val Acc: 0.9667


 82%|████████▏ | 1640/2000 [19:17<08:17,  1.38s/it]

Episode 1640 | Loss: 0.1502 | Train Acc: 1.0000 | Val Acc: 0.9500


 82%|████████▎ | 1650/2000 [19:25<09:05,  1.56s/it]

Episode 1650 | Loss: 0.0210 | Train Acc: 1.0000 | Val Acc: 0.9000


 83%|████████▎ | 1662/2000 [19:32<04:53,  1.15it/s]

Episode 1660 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.9000


 84%|████████▎ | 1670/2000 [19:38<06:21,  1.16s/it]

Episode 1670 | Loss: 0.0005 | Train Acc: 1.0000 | Val Acc: 0.9333


 84%|████████▍ | 1681/2000 [19:44<04:05,  1.30it/s]

Episode 1680 | Loss: 0.0357 | Train Acc: 1.0000 | Val Acc: 0.9333


 85%|████████▍ | 1692/2000 [19:51<04:15,  1.21it/s]

Episode 1690 | Loss: 0.0001 | Train Acc: 1.0000 | Val Acc: 0.9000


 85%|████████▌ | 1700/2000 [19:56<04:45,  1.05it/s]

Episode 1700 | Loss: 0.0972 | Train Acc: 1.0000 | Val Acc: 0.9500


 86%|████████▌ | 1710/2000 [20:03<06:23,  1.32s/it]

Episode 1710 | Loss: 0.0033 | Train Acc: 1.0000 | Val Acc: 0.9333


 86%|████████▌ | 1722/2000 [20:10<03:32,  1.31it/s]

Episode 1720 | Loss: 0.0015 | Train Acc: 1.0000 | Val Acc: 0.9500


 86%|████████▋ | 1730/2000 [20:16<04:33,  1.01s/it]

Episode 1730 | Loss: 0.0063 | Train Acc: 1.0000 | Val Acc: 0.9167


 87%|████████▋ | 1742/2000 [20:25<04:08,  1.04it/s]

Episode 1740 | Loss: 0.4060 | Train Acc: 0.6667 | Val Acc: 0.9000


 88%|████████▊ | 1752/2000 [20:31<03:08,  1.32it/s]

Episode 1750 | Loss: 0.0095 | Train Acc: 1.0000 | Val Acc: 0.9000


 88%|████████▊ | 1761/2000 [20:36<02:49,  1.41it/s]

Episode 1760 | Loss: 0.5535 | Train Acc: 0.6667 | Val Acc: 0.8500


 89%|████████▊ | 1771/2000 [20:42<02:41,  1.42it/s]

Episode 1770 | Loss: 0.4126 | Train Acc: 1.0000 | Val Acc: 0.8667


 89%|████████▉ | 1780/2000 [20:49<05:11,  1.41s/it]

Episode 1780 | Loss: 0.3982 | Train Acc: 1.0000 | Val Acc: 0.7500


 90%|████████▉ | 1791/2000 [20:55<03:50,  1.10s/it]

Episode 1790 | Loss: 0.0001 | Train Acc: 1.0000 | Val Acc: 0.8333


 90%|█████████ | 1800/2000 [21:02<04:06,  1.23s/it]

Episode 1800 | Loss: 0.2042 | Train Acc: 1.0000 | Val Acc: 0.9333


 90%|█████████ | 1810/2000 [21:07<02:33,  1.24it/s]

Episode 1810 | Loss: 0.0467 | Train Acc: 1.0000 | Val Acc: 0.8333


 91%|█████████ | 1820/2000 [21:14<03:34,  1.19s/it]

Episode 1820 | Loss: 0.8463 | Train Acc: 0.3333 | Val Acc: 0.9333


 92%|█████████▏| 1832/2000 [21:22<02:13,  1.25it/s]

Episode 1830 | Loss: 0.0002 | Train Acc: 1.0000 | Val Acc: 0.9333


 92%|█████████▏| 1841/2000 [21:29<03:05,  1.17s/it]

Episode 1840 | Loss: 0.3202 | Train Acc: 0.6667 | Val Acc: 0.9000


 92%|█████████▎| 1850/2000 [21:37<03:25,  1.37s/it]

Episode 1850 | Loss: 0.0568 | Train Acc: 1.0000 | Val Acc: 0.9667


 93%|█████████▎| 1860/2000 [21:44<03:40,  1.58s/it]

Episode 1860 | Loss: 0.0276 | Train Acc: 1.0000 | Val Acc: 0.9333


 94%|█████████▎| 1870/2000 [21:51<03:12,  1.48s/it]

Episode 1870 | Loss: 0.2967 | Train Acc: 0.6667 | Val Acc: 0.9833


 94%|█████████▍| 1882/2000 [21:57<01:23,  1.42it/s]

Episode 1880 | Loss: 0.0898 | Train Acc: 1.0000 | Val Acc: 0.8833


 95%|█████████▍| 1891/2000 [22:06<02:29,  1.37s/it]

Episode 1890 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.9333


 95%|█████████▌| 1902/2000 [22:14<01:47,  1.09s/it]

Episode 1900 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.9167


 96%|█████████▌| 1912/2000 [22:22<01:25,  1.02it/s]

Episode 1910 | Loss: 0.0067 | Train Acc: 1.0000 | Val Acc: 0.9500


 96%|█████████▌| 1920/2000 [22:27<01:21,  1.02s/it]

Episode 1920 | Loss: 0.0685 | Train Acc: 1.0000 | Val Acc: 0.9667


 97%|█████████▋| 1931/2000 [22:33<00:53,  1.28it/s]

Episode 1930 | Loss: 0.0009 | Train Acc: 1.0000 | Val Acc: 0.8167


 97%|█████████▋| 1940/2000 [22:38<00:45,  1.31it/s]

Episode 1940 | Loss: 0.1071 | Train Acc: 1.0000 | Val Acc: 0.9000


 98%|█████████▊| 1950/2000 [22:46<01:06,  1.33s/it]

Episode 1950 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.9333


 98%|█████████▊| 1960/2000 [22:52<00:37,  1.07it/s]

Episode 1960 | Loss: 0.0004 | Train Acc: 1.0000 | Val Acc: 0.9167


 99%|█████████▊| 1972/2000 [22:57<00:14,  1.97it/s]

Episode 1970 | Loss: 0.0216 | Train Acc: 1.0000 | Val Acc: 0.9000


 99%|█████████▉| 1980/2000 [23:03<00:22,  1.10s/it]

Episode 1980 | Loss: 0.0000 | Train Acc: 1.0000 | Val Acc: 0.9000


100%|█████████▉| 1991/2000 [23:10<00:09,  1.02s/it]

Episode 1990 | Loss: 8.4743 | Train Acc: 0.3333 | Val Acc: 0.9500


100%|██████████| 2000/2000 [23:14<00:00,  1.43it/s]

Episode 2000 | Loss: 0.0002 | Train Acc: 1.0000 | Val Acc: 0.8833



