In [1]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Question 2: Classification Using CNN

### Data Loading and Preprocessing

In [3]:
def load_mnist_data(path):
    splits = ['train', 'val', 'test']
    data = {'train': [], 'val': [], 'test': []}
    labels = {'train': [], 'val': [], 'test': []}
    for split in splits:
        split_path = os.path.join(path, split)
        for label in os.listdir(split_path):
            label_path = os.path.join(split_path, label)
            if (len(label)==1 and int(label) == 0):
                cur_label = 0
            else:
                cur_label = len(label)
            if os.path.isdir(label_path):
                for image_name in os.listdir(label_path):
                    image_path = os.path.join(label_path, image_name)
                    try:
                        image = Image.open(image_path).convert('L')
                        image_array = np.array(image)
                        data[split].append(image_array)
                        labels[split].append(cur_label)
                    except Exception as e:
                        print(f"Error loading image {image_name}: {e}")
    
    return data['train'], labels['train'], data['val'], labels['val'], data['test'], labels['test']



In [4]:
data_path = "./../../data/external/double_mnist"

train_data, train_labels, val_data, val_labels, test_data, test_labels = load_mnist_data(data_path)

In [5]:
class MultiMNISTDataset(Dataset):
    def __init__(self, images, labels, task = 'classification', transform=None):
        self.images = images
        self.labels = labels
        self.task = task
        self.transform = transform
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        image = torch.tensor(image, dtype=torch.float32) / 255.0
        image = image.unsqueeze(0)

        if self.transform:
            image = self.transform(image)
            
        if (self.task == 'classification'):
            label = torch.tensor(label, dtype=torch.long).to(device)
        elif (self.task == 'regression'):
            label = torch.tensor(label, dtype=torch.float32).to(device).unsqueeze(0)
        return image, label


### Implement the CNN Class

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

class CNN(nn.Module):
    def __init__(self, task='classification', num_classes=10):
        super(CNN, self).__init__()
        
        self.task = task
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        
        self.fc1 = None
        self.fc2 = nn.Linear(128, 64)
        
        if task == 'classification':
            self.fc3 = nn.Linear(64, num_classes)
        elif task == 'regression':
            self.fc3 = nn.Linear(64, 1)

    def _initialize_fc(self, input_shape, device):
        dummy_input = torch.zeros(1, *input_shape).to(device)
        with torch.no_grad():
            output = self._forward_conv(dummy_input)
        flattened_size = output.view(-1).shape[0]
        self.fc1 = nn.Linear(flattened_size, 128).to(device)

    def _forward_conv(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        return x

    def forward(self, x):
        if self.fc1 is None:
            self._initialize_fc(x.shape[1:], x.device)

        x = self._forward_conv(x)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def predict(self, x):
        self.eval()
        with torch.no_grad():
            y_pred = self.forward(x)
        return y_pred

    def get_accuracy(self, y_pred, y_true):
        if self.task == 'classification':
            y_pred = torch.argmax(y_pred, dim=1)
            return (y_pred == y_true).float().mean()
        elif self.task == 'regression':
            y_pred = torch.round(y_pred)
            return (y_pred == y_true).float().mean()

    def loss(self, y_pred, y_true):
        if self.task == 'classification':
            return F.cross_entropy(y_pred, y_true)
        elif self.task == 'regression':
            return F.mse_loss(y_pred, y_true)

    def train_model(self, optimizer, train_loader, val_loader, num_epochs=10, device='cpu'):
        self.to(device)
        for epoch in range(num_epochs):
            self.train()
            train_progress = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}")
            
            for i, (x, y) in train_progress:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                y_pred = self(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                optimizer.step()
                train_progress.set_postfix({"Loss": f"{loss.item():.4f}"})
            
            self.eval()
            total_accuracy = 0
            total_loss = 0
            with torch.no_grad():
                for x, y in val_loader:
                    x, y = x.to(device), y.to(device)
                    y_pred = self(x)
                    total_loss += self.loss(y_pred, y).item()
                    total_accuracy += self.get_accuracy(y_pred, y).item()
                    
                avg_loss = total_loss / len(val_loader)
                accuracy = total_accuracy / len(val_loader)
                print(f"Epoch {epoch + 1}, Validation Accuracy: {accuracy*100:.2f}%, Validation Loss: {avg_loss:.6f}")


In [7]:
train_data_classification = MultiMNISTDataset(train_data, train_labels)
val_data_classification = MultiMNISTDataset(val_data, val_labels)
test_data_classification = MultiMNISTDataset(test_data, test_labels)

train_loader_classification = DataLoader(train_data_classification, batch_size=32, shuffle=True)
val_loader_classification = DataLoader(val_data_classification, batch_size=32, shuffle=False)
test_loader_classification = DataLoader(test_data_classification, batch_size=32, shuffle=False)

In [8]:
model_single_classification = CNN(task='classification', num_classes=4).to(device)
optimizer = torch.optim.Adam(model_single_classification.parameters(), lr=0.001)
model_single_classification.train_model(optimizer, train_loader_classification, val_loader_classification, num_epochs=20, device=device)

Epoch 1/20: 100%|██████████| 394/394 [00:09<00:00, 42.38it/s, Loss=0.0207]


Epoch 1, Validation Accuracy: 97.96%, Validation Loss: 0.051099


Epoch 2/20: 100%|██████████| 394/394 [00:08<00:00, 48.43it/s, Loss=0.0132]


Epoch 2, Validation Accuracy: 98.70%, Validation Loss: 0.044255


Epoch 3/20: 100%|██████████| 394/394 [00:07<00:00, 54.48it/s, Loss=0.0013]


Epoch 3, Validation Accuracy: 99.66%, Validation Loss: 0.013400


Epoch 4/20: 100%|██████████| 394/394 [00:06<00:00, 62.06it/s, Loss=0.0015]


Epoch 4, Validation Accuracy: 99.66%, Validation Loss: 0.011013


Epoch 5/20: 100%|██████████| 394/394 [00:06<00:00, 62.19it/s, Loss=0.0026]


Epoch 5, Validation Accuracy: 99.50%, Validation Loss: 0.010241


Epoch 6/20: 100%|██████████| 394/394 [00:06<00:00, 61.62it/s, Loss=0.0007]


Epoch 6, Validation Accuracy: 99.80%, Validation Loss: 0.006724


Epoch 7/20: 100%|██████████| 394/394 [00:06<00:00, 62.66it/s, Loss=0.0004]


Epoch 7, Validation Accuracy: 99.80%, Validation Loss: 0.006226


Epoch 8/20: 100%|██████████| 394/394 [00:06<00:00, 61.80it/s, Loss=0.0000]


Epoch 8, Validation Accuracy: 99.83%, Validation Loss: 0.006409


Epoch 9/20: 100%|██████████| 394/394 [00:06<00:00, 61.34it/s, Loss=0.0001]


Epoch 9, Validation Accuracy: 99.97%, Validation Loss: 0.003287


Epoch 10/20: 100%|██████████| 394/394 [00:06<00:00, 61.19it/s, Loss=0.0009]


Epoch 10, Validation Accuracy: 98.05%, Validation Loss: 0.044844


Epoch 11/20: 100%|██████████| 394/394 [00:06<00:00, 60.21it/s, Loss=0.0002]


Epoch 11, Validation Accuracy: 99.90%, Validation Loss: 0.002566


Epoch 12/20: 100%|██████████| 394/394 [00:07<00:00, 54.94it/s, Loss=0.0000]


Epoch 12, Validation Accuracy: 99.83%, Validation Loss: 0.005786


Epoch 13/20: 100%|██████████| 394/394 [00:06<00:00, 56.88it/s, Loss=0.0000]


Epoch 13, Validation Accuracy: 98.76%, Validation Loss: 0.032227


Epoch 14/20: 100%|██████████| 394/394 [00:06<00:00, 60.70it/s, Loss=0.0000]


Epoch 14, Validation Accuracy: 99.93%, Validation Loss: 0.001654


Epoch 15/20: 100%|██████████| 394/394 [00:06<00:00, 60.47it/s, Loss=0.0000]


Epoch 15, Validation Accuracy: 99.93%, Validation Loss: 0.001927


Epoch 16/20: 100%|██████████| 394/394 [00:05<00:00, 70.21it/s, Loss=0.0000]


Epoch 16, Validation Accuracy: 99.97%, Validation Loss: 0.001302


Epoch 17/20: 100%|██████████| 394/394 [00:04<00:00, 81.45it/s, Loss=0.0000]


Epoch 17, Validation Accuracy: 99.97%, Validation Loss: 0.001187


Epoch 18/20: 100%|██████████| 394/394 [00:06<00:00, 63.59it/s, Loss=0.0000]


Epoch 18, Validation Accuracy: 99.97%, Validation Loss: 0.001103


Epoch 19/20: 100%|██████████| 394/394 [00:06<00:00, 63.04it/s, Loss=0.0000]


Epoch 19, Validation Accuracy: 99.97%, Validation Loss: 0.001050


Epoch 20/20: 100%|██████████| 394/394 [00:06<00:00, 63.16it/s, Loss=0.0000]


Epoch 20, Validation Accuracy: 99.97%, Validation Loss: 0.001115


In [9]:
train_data_regression = MultiMNISTDataset(train_data, train_labels, task='regression')
val_data_regression = MultiMNISTDataset(val_data, val_labels, task='regression')
test_data_regression = MultiMNISTDataset(test_data, test_labels, task='regression')

train_loader_regression = DataLoader(train_data_regression, batch_size=32, shuffle=True)
val_loader_regression = DataLoader(val_data_regression, batch_size=32, shuffle=False)
test_loader_regression = DataLoader(test_data_regression, batch_size=32, shuffle=False)

In [10]:
model_single_regression = CNN(task='regression').to(device)
optimizer = torch.optim.Adam(model_single_regression.parameters(), lr=0.001)
model_single_regression.train_model(optimizer, train_loader_regression, val_loader_regression, num_epochs=20, device=device)

Epoch 1/20: 100%|██████████| 394/394 [00:08<00:00, 48.71it/s, Loss=0.0640]


Epoch 1, Validation Accuracy: 96.13%, Validation Loss: 0.054337


Epoch 2/20: 100%|██████████| 394/394 [00:07<00:00, 51.71it/s, Loss=0.0397]


Epoch 2, Validation Accuracy: 98.34%, Validation Loss: 0.033815


Epoch 3/20: 100%|██████████| 394/394 [00:06<00:00, 58.17it/s, Loss=0.0321]


Epoch 3, Validation Accuracy: 99.24%, Validation Loss: 0.024840


Epoch 4/20: 100%|██████████| 394/394 [00:07<00:00, 51.75it/s, Loss=0.0235]


Epoch 4, Validation Accuracy: 99.73%, Validation Loss: 0.020179


Epoch 5/20: 100%|██████████| 394/394 [00:06<00:00, 58.85it/s, Loss=0.0159]


Epoch 5, Validation Accuracy: 99.83%, Validation Loss: 0.016153


Epoch 6/20: 100%|██████████| 394/394 [00:06<00:00, 58.24it/s, Loss=0.0220]


Epoch 6, Validation Accuracy: 99.87%, Validation Loss: 0.014123


Epoch 7/20: 100%|██████████| 394/394 [00:07<00:00, 51.58it/s, Loss=0.0139]


Epoch 7, Validation Accuracy: 99.93%, Validation Loss: 0.016783


Epoch 8/20: 100%|██████████| 394/394 [00:06<00:00, 57.96it/s, Loss=0.0068]


Epoch 8, Validation Accuracy: 99.97%, Validation Loss: 0.012877


Epoch 9/20: 100%|██████████| 394/394 [00:06<00:00, 57.34it/s, Loss=0.0054]


Epoch 9, Validation Accuracy: 99.97%, Validation Loss: 0.009824


Epoch 10/20: 100%|██████████| 394/394 [00:07<00:00, 52.11it/s, Loss=0.0071]


Epoch 10, Validation Accuracy: 99.93%, Validation Loss: 0.009601


Epoch 11/20: 100%|██████████| 394/394 [00:06<00:00, 57.75it/s, Loss=0.0069]


Epoch 11, Validation Accuracy: 100.00%, Validation Loss: 0.008282


Epoch 12/20: 100%|██████████| 394/394 [00:06<00:00, 57.54it/s, Loss=0.0081]


Epoch 12, Validation Accuracy: 100.00%, Validation Loss: 0.007918


Epoch 13/20: 100%|██████████| 394/394 [00:07<00:00, 51.67it/s, Loss=0.0037]


Epoch 13, Validation Accuracy: 100.00%, Validation Loss: 0.007008


Epoch 14/20: 100%|██████████| 394/394 [00:07<00:00, 54.95it/s, Loss=0.0071]


Epoch 14, Validation Accuracy: 100.00%, Validation Loss: 0.007886


Epoch 15/20: 100%|██████████| 394/394 [00:06<00:00, 57.73it/s, Loss=0.0041]


Epoch 15, Validation Accuracy: 100.00%, Validation Loss: 0.006777


Epoch 16/20: 100%|██████████| 394/394 [00:07<00:00, 53.90it/s, Loss=0.0038]


Epoch 16, Validation Accuracy: 100.00%, Validation Loss: 0.006428


Epoch 17/20: 100%|██████████| 394/394 [00:07<00:00, 54.74it/s, Loss=0.0046]


Epoch 17, Validation Accuracy: 100.00%, Validation Loss: 0.006454


Epoch 18/20: 100%|██████████| 394/394 [00:06<00:00, 57.13it/s, Loss=0.0052]


Epoch 18, Validation Accuracy: 100.00%, Validation Loss: 0.005759


Epoch 19/20: 100%|██████████| 394/394 [00:07<00:00, 53.15it/s, Loss=0.0026]


Epoch 19, Validation Accuracy: 100.00%, Validation Loss: 0.005580


Epoch 20/20: 100%|██████████| 394/394 [00:07<00:00, 55.93it/s, Loss=0.0027]


Epoch 20, Validation Accuracy: 100.00%, Validation Loss: 0.005763


In [11]:
def load_mnist_data_multilabel(path):
    splits = ['train', 'val', 'test']
    data = {'train': [], 'val': [], 'test': []}
    labels = {'train': [], 'val': [], 'test': []}
    for split in splits:
        split_path = os.path.join(path, split)
        for label in os.listdir(split_path):
            label_path = os.path.join(split_path, label)
            if (len(label) == 1 and int(label) == 0):
                cur_label = []
            else:
                cur_label = [int(i) for i in label]
            if os.path.isdir(label_path):
                for image_name in os.listdir(label_path):
                    image_path = os.path.join(label_path, image_name)
                    try:
                        image = Image.open(image_path).convert('L')
                        image_array = np.array(image)
                        data[split].append(image_array)
                        labels[split].append(cur_label)
                    except Exception as e:
                        print(f"Error loading image {image_name}: {e}")
    
    return data['train'], labels['train'], data['val'], labels['val'], data['test'], labels['test']

In [12]:
def get_one_hot(y):
    one_hot = np.zeros(33)
    if len(y) == 0:
        y.append(10)
        y.append(10)
        y.append(10)
    elif len(y) == 1:
        y.append(10)
        y.append(10)
    elif len(y) == 2:
        y.append(10)
    for i in range(len(y)):
        one_hot[y[i]+10*i+i] = 1
    return one_hot

In [13]:
class MultiMNISTDataset_mulitlabel(Dataset):
    def __init__(self, images, labels, task='multilabel_classification', transform=None):
        self.images = images
        self.labels = labels
        self.task = task
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        image = torch.tensor(image, dtype=torch.float32) / 255.0
        image = image.unsqueeze(0)

        if self.transform:
            image = self.transform(image)

        if self.task == 'multilabel_classification':
            label = torch.tensor(label, dtype=torch.float32).to(device)
        else:
            label = torch.tensor(label, dtype=torch.float32).to(device).unsqueeze(0)
        return image, label


In [14]:
class CNN_Multilabel(nn.Module):
    def __init__(self, task='multilabel_classification', num_classes=33):
        super(CNN_Multilabel, self).__init__()
        
        self.task = task
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        
        self.fc1 = None
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)

    def _initialize_fc(self, input_shape, device):
        dummy_input = torch.zeros(1, *input_shape).to(device)
        with torch.no_grad():
            output = self._forward_conv(dummy_input)
        flattened_size = output.view(-1).shape[0]
        self.fc1 = nn.Linear(flattened_size, 128).to(device)

    def _forward_conv(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        return x

    def forward(self, x):
        if self.fc1 is None:
            self._initialize_fc(x.shape[1:], x.device)

        x = self._forward_conv(x)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def get_accuracy(self, y_pred, y_true):
        if self.task == 'multilabel_classification':
            y_pred_labels = self.convert_to_labels(y_pred)
            y_true_labels = self.convert_to_labels(y_true)
            accuracy = 0
            for i in range(len(y_pred_labels)):
                flag = 1
                if len(y_pred_labels[i]) == len(y_true_labels[i]):
                    for j in range(len(y_pred_labels[i])):
                        if y_pred_labels[i][j] != y_true_labels[i][j]:
                            flag = 0
                            break
                else:
                    flag = 0
                accuracy += flag
            return accuracy/len(y_pred_labels)
        else:
            return F.mse_loss(y_pred, y_true)


    def loss(self, y_pred, y_true):
        if self.task == 'multilabel_classification':
            criterion = nn.CrossEntropyLoss()
            loss = 0
            for i in range(3):
                start = i * 11
                end = (i + 1) * 11
                target_idx = torch.argmax(y_true[:, start:end], dim=1)
                loss += criterion(y_pred[:, start:end], target_idx)
            return loss
        else:
            return F.mse_loss(y_pred, y_true)
            
    def convert_to_labels(self, y_pred):
        labels = []
        batch_size = y_pred.shape[0]
        for i in range(batch_size):
            label = []
            for j in range(3):
                start = j * 11
                end = (j + 1) * 11
                cur_label = torch.argmax(y_pred[i, start:end]).item()
                if cur_label != 10:
                    label.append(cur_label)
                else:
                    break
            labels.append(label)
        return labels
    
    def train_model(self, optimizer, train_loader, val_loader, num_epochs=30, device='cpu'):
        self.to(device)
        for epoch in range(num_epochs):
            self.train()
            train_progress = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}")
            
            for i, (x, y) in train_progress:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                y_pred = self(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                optimizer.step()
                train_progress.set_postfix({"Loss": f"{loss.item():.4f}"})
            
            self.eval()
            total_accuracy = 0
            total_loss = 0
            with torch.no_grad():
                for x, y in val_loader:
                    x, y = x.to(device), y.to(device)
                    y_pred = self(x)
                    total_loss += self.loss(y_pred, y).item()
                    total_accuracy += self.get_accuracy(y_pred, y)

                avg_loss = total_loss / len(val_loader)
                accuracy = total_accuracy / len(val_loader)
                print(f"Epoch {epoch + 1}, Validation Accuracy: {accuracy*100:.2f}%, Validation Loss: {avg_loss:.6f}")

In [15]:
train_data_multilabel, train_labels_multilabel, val_data_multilabel, val_labels_multilabel, test_data_multilabel, test_labels_multilabel = load_mnist_data_multilabel(data_path)

In [16]:
train_labels_multilabel_one_hot = [get_one_hot(y) for y in train_labels_multilabel]
val_labels_multilabel_one_hot = [get_one_hot(y) for y in val_labels_multilabel]
test_labels_multilabel_one_hot = [get_one_hot(y) for y in test_labels_multilabel]

In [17]:

train_data_multilabel = MultiMNISTDataset_mulitlabel(train_data_multilabel, train_labels_multilabel_one_hot)
val_data_multilabel = MultiMNISTDataset_mulitlabel(val_data_multilabel, val_labels_multilabel_one_hot)
test_data_multilabel = MultiMNISTDataset_mulitlabel(test_data_multilabel, test_labels_multilabel_one_hot)

train_loader_multilabel = DataLoader(train_data_multilabel, batch_size=32, shuffle=True)
val_loader_multilabel = DataLoader(val_data_multilabel, batch_size=32, shuffle=False)
test_loader_multilabel = DataLoader(test_data_multilabel, batch_size=32, shuffle=False)

In [18]:
model_multilabel_classfication = CNN_Multilabel(task='multilabel_classification', num_classes=33).to(device)

optimizer = torch.optim.Adam(model_multilabel_classfication.parameters(), lr=0.001)
model_multilabel_classfication.train_model(optimizer, train_loader_multilabel, val_loader_multilabel, num_epochs=50, device=device)

Epoch 1/50: 100%|██████████| 394/394 [00:07<00:00, 51.94it/s, Loss=4.0004]


Epoch 1, Validation Accuracy: 3.36%, Validation Loss: 4.991885


Epoch 2/50: 100%|██████████| 394/394 [00:08<00:00, 45.60it/s, Loss=5.0482]


Epoch 2, Validation Accuracy: 4.99%, Validation Loss: 4.996034


Epoch 3/50: 100%|██████████| 394/394 [00:07<00:00, 51.55it/s, Loss=3.8705]


Epoch 3, Validation Accuracy: 5.05%, Validation Loss: 4.956559


Epoch 4/50: 100%|██████████| 394/394 [00:08<00:00, 45.94it/s, Loss=4.3851]


Epoch 4, Validation Accuracy: 4.39%, Validation Loss: 4.793960


Epoch 5/50: 100%|██████████| 394/394 [00:07<00:00, 52.85it/s, Loss=4.6114]


Epoch 5, Validation Accuracy: 4.62%, Validation Loss: 4.836658


Epoch 6/50: 100%|██████████| 394/394 [00:07<00:00, 51.71it/s, Loss=2.8655]


Epoch 6, Validation Accuracy: 5.05%, Validation Loss: 4.652598


Epoch 7/50: 100%|██████████| 394/394 [00:08<00:00, 44.46it/s, Loss=3.6225]


Epoch 7, Validation Accuracy: 5.45%, Validation Loss: 4.497826


Epoch 8/50: 100%|██████████| 394/394 [00:12<00:00, 32.69it/s, Loss=2.8918]


Epoch 8, Validation Accuracy: 6.31%, Validation Loss: 4.500981


Epoch 9/50: 100%|██████████| 394/394 [00:12<00:00, 31.04it/s, Loss=2.2456]


Epoch 9, Validation Accuracy: 5.16%, Validation Loss: 4.717444


Epoch 10/50: 100%|██████████| 394/394 [00:09<00:00, 43.54it/s, Loss=2.6029]


Epoch 10, Validation Accuracy: 6.79%, Validation Loss: 4.606946


Epoch 11/50: 100%|██████████| 394/394 [00:09<00:00, 43.54it/s, Loss=2.3288]


Epoch 11, Validation Accuracy: 8.20%, Validation Loss: 4.438222


Epoch 12/50: 100%|██████████| 394/394 [00:10<00:00, 37.34it/s, Loss=2.5140]


Epoch 12, Validation Accuracy: 7.91%, Validation Loss: 4.646089


Epoch 13/50: 100%|██████████| 394/394 [00:08<00:00, 44.01it/s, Loss=2.0235]


Epoch 13, Validation Accuracy: 10.20%, Validation Loss: 4.399217


Epoch 14/50: 100%|██████████| 394/394 [00:10<00:00, 38.93it/s, Loss=1.7224]


Epoch 14, Validation Accuracy: 10.44%, Validation Loss: 4.463925


Epoch 15/50: 100%|██████████| 394/394 [00:08<00:00, 44.84it/s, Loss=2.6382]


Epoch 15, Validation Accuracy: 13.34%, Validation Loss: 4.532721


Epoch 16/50: 100%|██████████| 394/394 [00:08<00:00, 43.84it/s, Loss=1.6634]


Epoch 16, Validation Accuracy: 16.81%, Validation Loss: 3.878631


Epoch 17/50: 100%|██████████| 394/394 [00:10<00:00, 39.30it/s, Loss=2.1148]


Epoch 17, Validation Accuracy: 18.40%, Validation Loss: 4.197657


Epoch 18/50: 100%|██████████| 394/394 [00:08<00:00, 44.71it/s, Loss=1.6613]


Epoch 18, Validation Accuracy: 19.04%, Validation Loss: 4.263375


Epoch 19/50: 100%|██████████| 394/394 [00:09<00:00, 39.80it/s, Loss=0.9717]


Epoch 19, Validation Accuracy: 21.49%, Validation Loss: 4.427503


Epoch 20/50: 100%|██████████| 394/394 [00:08<00:00, 44.19it/s, Loss=1.1859]


Epoch 20, Validation Accuracy: 25.08%, Validation Loss: 3.755477


Epoch 21/50: 100%|██████████| 394/394 [00:09<00:00, 43.50it/s, Loss=1.7609]


Epoch 21, Validation Accuracy: 26.12%, Validation Loss: 3.850917


Epoch 22/50: 100%|██████████| 394/394 [00:08<00:00, 45.08it/s, Loss=2.2319]


Epoch 22, Validation Accuracy: 27.08%, Validation Loss: 4.150705


Epoch 23/50: 100%|██████████| 394/394 [00:09<00:00, 43.44it/s, Loss=1.2351]


Epoch 23, Validation Accuracy: 26.96%, Validation Loss: 3.797474


Epoch 24/50: 100%|██████████| 394/394 [00:10<00:00, 36.26it/s, Loss=1.3794]


Epoch 24, Validation Accuracy: 28.00%, Validation Loss: 3.922247


Epoch 25/50: 100%|██████████| 394/394 [00:09<00:00, 42.35it/s, Loss=0.6058]


Epoch 25, Validation Accuracy: 28.70%, Validation Loss: 3.787590


Epoch 26/50: 100%|██████████| 394/394 [00:09<00:00, 43.59it/s, Loss=0.7118]


Epoch 26, Validation Accuracy: 30.18%, Validation Loss: 3.990231


Epoch 27/50: 100%|██████████| 394/394 [00:09<00:00, 41.33it/s, Loss=0.5401]


Epoch 27, Validation Accuracy: 30.86%, Validation Loss: 4.027424


Epoch 28/50: 100%|██████████| 394/394 [00:09<00:00, 43.47it/s, Loss=1.3465]


Epoch 28, Validation Accuracy: 31.15%, Validation Loss: 4.167344


Epoch 29/50: 100%|██████████| 394/394 [00:11<00:00, 34.94it/s, Loss=0.9452]


Epoch 29, Validation Accuracy: 32.28%, Validation Loss: 4.035338


Epoch 30/50: 100%|██████████| 394/394 [00:10<00:00, 37.95it/s, Loss=0.5327]


Epoch 30, Validation Accuracy: 32.70%, Validation Loss: 3.833912


Epoch 31/50: 100%|██████████| 394/394 [00:09<00:00, 40.88it/s, Loss=0.2377]


Epoch 31, Validation Accuracy: 29.43%, Validation Loss: 4.307982


Epoch 32/50: 100%|██████████| 394/394 [00:09<00:00, 40.80it/s, Loss=0.5919]


Epoch 32, Validation Accuracy: 32.27%, Validation Loss: 4.178316


Epoch 33/50: 100%|██████████| 394/394 [00:09<00:00, 43.08it/s, Loss=0.7737]


Epoch 33, Validation Accuracy: 30.11%, Validation Loss: 4.958003


Epoch 34/50: 100%|██████████| 394/394 [00:10<00:00, 39.39it/s, Loss=0.4810]


Epoch 34, Validation Accuracy: 32.01%, Validation Loss: 4.260702


Epoch 35/50: 100%|██████████| 394/394 [00:08<00:00, 43.89it/s, Loss=0.3419]


Epoch 35, Validation Accuracy: 32.85%, Validation Loss: 4.297442


Epoch 36/50: 100%|██████████| 394/394 [00:08<00:00, 44.35it/s, Loss=0.8156]


Epoch 36, Validation Accuracy: 32.27%, Validation Loss: 4.672088


Epoch 37/50: 100%|██████████| 394/394 [00:09<00:00, 41.18it/s, Loss=0.6923]


Epoch 37, Validation Accuracy: 32.13%, Validation Loss: 4.996537


Epoch 38/50: 100%|██████████| 394/394 [00:08<00:00, 44.26it/s, Loss=0.4375]


Epoch 38, Validation Accuracy: 33.74%, Validation Loss: 4.771185


Epoch 39/50: 100%|██████████| 394/394 [00:09<00:00, 39.93it/s, Loss=0.5671]


Epoch 39, Validation Accuracy: 32.37%, Validation Loss: 5.047542


Epoch 40/50: 100%|██████████| 394/394 [00:09<00:00, 43.08it/s, Loss=0.2409]


Epoch 40, Validation Accuracy: 33.00%, Validation Loss: 4.994994


Epoch 41/50: 100%|██████████| 394/394 [00:08<00:00, 43.84it/s, Loss=0.2453]


Epoch 41, Validation Accuracy: 32.57%, Validation Loss: 4.863742


Epoch 42/50: 100%|██████████| 394/394 [00:09<00:00, 42.88it/s, Loss=0.4859]


Epoch 42, Validation Accuracy: 32.03%, Validation Loss: 5.535784


Epoch 43/50: 100%|██████████| 394/394 [00:09<00:00, 42.91it/s, Loss=0.2344]


Epoch 43, Validation Accuracy: 34.94%, Validation Loss: 5.118845


Epoch 44/50: 100%|██████████| 394/394 [00:10<00:00, 38.51it/s, Loss=0.1686]


Epoch 44, Validation Accuracy: 33.27%, Validation Loss: 5.323517


Epoch 45/50: 100%|██████████| 394/394 [00:09<00:00, 41.89it/s, Loss=0.3570]


Epoch 45, Validation Accuracy: 33.62%, Validation Loss: 5.373346


Epoch 46/50: 100%|██████████| 394/394 [00:09<00:00, 42.39it/s, Loss=0.1316]


Epoch 46, Validation Accuracy: 33.71%, Validation Loss: 5.709679


Epoch 47/50: 100%|██████████| 394/394 [00:08<00:00, 44.37it/s, Loss=0.2602]


Epoch 47, Validation Accuracy: 33.69%, Validation Loss: 5.974141


Epoch 48/50: 100%|██████████| 394/394 [00:08<00:00, 44.07it/s, Loss=0.1508]


Epoch 48, Validation Accuracy: 34.84%, Validation Loss: 5.392631


Epoch 49/50: 100%|██████████| 394/394 [00:09<00:00, 40.48it/s, Loss=0.1311]


Epoch 49, Validation Accuracy: 33.80%, Validation Loss: 6.574523


Epoch 50/50: 100%|██████████| 394/394 [00:09<00:00, 42.48it/s, Loss=0.1853]


Epoch 50, Validation Accuracy: 32.67%, Validation Loss: 6.374373
