Intializing the functions + parameters + data for the training and test.

In [1]:
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
from torch import optim
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from tqdm import tqdm
from models import*

def train_model(model, train_loader, criterion, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        
        for data, labels in tqdm(train_loader, desc=f'Epoch {epoch}/{num_epochs}', unit='batch'):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
            _, predicted = torch.max(output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
        average_loss = total_loss / len(train_loader)
        accuracy = correct / total
        
        print(f'Epoch {epoch}/{num_epochs}, Average Loss: {average_loss:.4f}, Accuracy: {accuracy * 100:.2f}%')

    return model

def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, labels in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f'Test Accuracy: {accuracy * 100:.2f}%')

def view_data_sample(loader):
    image, label = next(iter(loader))
    plt.figure(figsize=(16, 8))
    plt.axis('off')
    plt.imshow(make_grid(image, nrow=16).permute((1, 2, 0)))

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def splice_batch(X, Y, num_of_labels, prints=True):
    if prints:
        print('input: ', end="")
        print("\t X shape: ", X.shape, end='\t')
        print("\t Y shape: ", Y.shape)
    X = X[Y < num_of_labels]
    Y = Y[Y < num_of_labels]
    if prints:
        print('output: ', end="")
        print("\t X shape: ", X.shape, end='\t')
        print("\t Y shape: ", Y.shape)
    return X, Y



# Parameters
batch_size = 512
lr = 0.001
num_epochs = 10

# Download and load the training data
trainset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=True, transform=ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

# Download and load the test data
testset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=False, transform=ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

# Data for the first model, only 3 classes
train_data_3, train_labels_3 = splice_batch(trainset.data, trainset.targets, num_of_labels=3)
train_data_3 = train_data_3.float() / 255.0  # Normalization
train_dataset_3 = torch.utils.data.TensorDataset(train_data_3, train_labels_3)
train_loader_3 = torch.utils.data.DataLoader(train_dataset_3, batch_size=batch_size, shuffle=True)

test_data_3, test_labels_3 = splice_batch(testset.data, testset.targets, num_of_labels=3)
test_data_3 = test_data_3.float() / 255.0  # Normalization
test_dataset_3 = torch.utils.data.TensorDataset(test_data_3, test_labels_3)
test_loader_3 = torch.utils.data.DataLoader(test_dataset_3, batch_size=batch_size, shuffle=False)



# Data for the first model, only 7 classes
train_data_7, train_labels_7 = splice_batch(trainset.data, trainset.targets, num_of_labels=7)
train_data_7 = train_data_7.float() / 255.0  # Normalization
train_dataset_7 = torch.utils.data.TensorDataset(train_data_7, train_labels_7)
train_loader_7 = torch.utils.data.DataLoader(train_dataset_7, batch_size=batch_size, shuffle=True)

test_data_7, test_labels_7 = splice_batch(testset.data, testset.targets, num_of_labels=7)
test_data_7 = test_data_7.float() / 255.0  # Normalization
test_dataset_7 = torch.utils.data.TensorDataset(test_data_7, test_labels_7)
test_loader_7 = torch.utils.data.DataLoader(test_dataset_7, batch_size=batch_size, shuffle=False)

input: 	 X shape:  torch.Size([60000, 28, 28])		 Y shape:  torch.Size([60000])
output: 	 X shape:  torch.Size([18000, 28, 28])		 Y shape:  torch.Size([18000])
input: 	 X shape:  torch.Size([10000, 28, 28])		 Y shape:  torch.Size([10000])
output: 	 X shape:  torch.Size([3000, 28, 28])		 Y shape:  torch.Size([3000])
input: 	 X shape:  torch.Size([60000, 28, 28])		 Y shape:  torch.Size([60000])
output: 	 X shape:  torch.Size([42000, 28, 28])		 Y shape:  torch.Size([42000])
input: 	 X shape:  torch.Size([10000, 28, 28])		 Y shape:  torch.Size([10000])
output: 	 X shape:  torch.Size([7000, 28, 28])		 Y shape:  torch.Size([7000])


In [3]:
# Train the first model
model_1 = model_1()
criterion_1 = nn.CrossEntropyLoss()
optimizer_1 = optim.Adam(model_1.parameters(), lr=lr)
model_1 = train_model(model_1, train_loader_3, criterion_1, optimizer_1, num_epochs)

# Test the first model
test_model(model_1, test_loader_3)
print(count_parameters(model_1))

Epoch 0/10: 100%|██████████| 36/36 [00:00<00:00, 149.51batch/s]


Epoch 0/10, Average Loss: 0.5738, Accuracy: 82.27%


Epoch 1/10: 100%|██████████| 36/36 [00:00<00:00, 135.52batch/s]


Epoch 1/10, Average Loss: 0.1695, Accuracy: 95.38%


Epoch 2/10: 100%|██████████| 36/36 [00:00<00:00, 167.88batch/s]


Epoch 2/10, Average Loss: 0.1363, Accuracy: 96.06%


Epoch 3/10: 100%|██████████| 36/36 [00:00<00:00, 143.26batch/s]


Epoch 3/10, Average Loss: 0.1226, Accuracy: 96.31%


Epoch 4/10: 100%|██████████| 36/36 [00:00<00:00, 145.60batch/s]


Epoch 4/10, Average Loss: 0.1135, Accuracy: 96.65%


Epoch 5/10: 100%|██████████| 36/36 [00:00<00:00, 181.10batch/s]


Epoch 5/10, Average Loss: 0.1107, Accuracy: 96.82%


Epoch 6/10: 100%|██████████| 36/36 [00:00<00:00, 187.79batch/s]


Epoch 6/10, Average Loss: 0.1046, Accuracy: 96.92%


Epoch 7/10: 100%|██████████| 36/36 [00:00<00:00, 205.26batch/s]


Epoch 7/10, Average Loss: 0.1005, Accuracy: 97.07%


Epoch 8/10: 100%|██████████| 36/36 [00:00<00:00, 194.69batch/s]


Epoch 8/10, Average Loss: 0.0977, Accuracy: 97.08%


Epoch 9/10: 100%|██████████| 36/36 [00:00<00:00, 160.60batch/s]

Epoch 9/10, Average Loss: 0.0973, Accuracy: 97.07%
Test Accuracy: 96.30%
48383





In [4]:
# Train the second model
model_2 = model_2()
criterion_2 = nn.CrossEntropyLoss()
optimizer_2 = optim.Adam(model_2.parameters(), lr=lr)
model_2 = train_model(model_2, train_loader_7, criterion_2, optimizer_2, num_epochs)
# Test the second model
test_model(model_2, test_loader_7)
print(count_parameters(model_2))

Epoch 0/10: 100%|██████████| 83/83 [00:00<00:00, 191.77batch/s]


Epoch 0/10, Average Loss: 1.1087, Accuracy: 60.50%


Epoch 1/10: 100%|██████████| 83/83 [00:00<00:00, 224.22batch/s]


Epoch 1/10, Average Loss: 0.6394, Accuracy: 75.69%


Epoch 2/10: 100%|██████████| 83/83 [00:00<00:00, 203.55batch/s]


Epoch 2/10, Average Loss: 0.5620, Accuracy: 79.39%


Epoch 3/10: 100%|██████████| 83/83 [00:00<00:00, 247.20batch/s]


Epoch 3/10, Average Loss: 0.5102, Accuracy: 81.60%


Epoch 4/10: 100%|██████████| 83/83 [00:00<00:00, 201.46batch/s]


Epoch 4/10, Average Loss: 0.4811, Accuracy: 82.45%


Epoch 5/10: 100%|██████████| 83/83 [00:00<00:00, 199.95batch/s]


Epoch 5/10, Average Loss: 0.4609, Accuracy: 83.25%


Epoch 6/10: 100%|██████████| 83/83 [00:00<00:00, 159.94batch/s]


Epoch 6/10, Average Loss: 0.4447, Accuracy: 83.70%


Epoch 7/10: 100%|██████████| 83/83 [00:00<00:00, 172.30batch/s]


Epoch 7/10, Average Loss: 0.4410, Accuracy: 84.12%


Epoch 8/10: 100%|██████████| 83/83 [00:00<00:00, 175.77batch/s]


Epoch 8/10, Average Loss: 0.4345, Accuracy: 84.38%


Epoch 9/10: 100%|██████████| 83/83 [00:00<00:00, 112.01batch/s]

Epoch 9/10, Average Loss: 0.4343, Accuracy: 84.00%
Test Accuracy: 82.99%
48467





In [2]:
# Train the third model
model_3 = model_3()
criterion_3 = nn.CrossEntropyLoss()
optimizer_3 = optim.Adam(model_3.parameters(), lr=lr)
model_3 = train_model(model_3, train_loader_7, criterion_3, optimizer_3, 30)
# Test the third model
test_model(model_3, test_loader_7)
print(count_parameters(model_3))

Epoch 0/30: 100%|██████████| 83/83 [00:00<00:00, 152.63batch/s]


Epoch 0/30, Average Loss: 1.5051, Accuracy: 41.09%


Epoch 1/30: 100%|██████████| 83/83 [00:00<00:00, 174.97batch/s]


Epoch 1/30, Average Loss: 0.8960, Accuracy: 63.94%


Epoch 2/30: 100%|██████████| 83/83 [00:00<00:00, 137.76batch/s]


Epoch 2/30, Average Loss: 0.7754, Accuracy: 68.17%


Epoch 3/30: 100%|██████████| 83/83 [00:00<00:00, 173.72batch/s]


Epoch 3/30, Average Loss: 0.7120, Accuracy: 71.24%


Epoch 4/30: 100%|██████████| 83/83 [00:00<00:00, 168.07batch/s]


Epoch 4/30, Average Loss: 0.6818, Accuracy: 72.72%


Epoch 5/30: 100%|██████████| 83/83 [00:00<00:00, 96.41batch/s] 


Epoch 5/30, Average Loss: 0.6450, Accuracy: 74.64%


Epoch 6/30: 100%|██████████| 83/83 [00:00<00:00, 156.75batch/s]


Epoch 6/30, Average Loss: 0.6170, Accuracy: 76.65%


Epoch 7/30: 100%|██████████| 83/83 [00:00<00:00, 180.11batch/s]


Epoch 7/30, Average Loss: 0.5950, Accuracy: 77.60%


Epoch 8/30: 100%|██████████| 83/83 [00:00<00:00, 140.14batch/s]


Epoch 8/30, Average Loss: 0.5775, Accuracy: 78.65%


Epoch 9/30: 100%|██████████| 83/83 [00:00<00:00, 195.20batch/s]


Epoch 9/30, Average Loss: 0.5600, Accuracy: 79.33%


Epoch 10/30: 100%|██████████| 83/83 [00:00<00:00, 183.76batch/s]


Epoch 10/30, Average Loss: 0.5487, Accuracy: 79.96%


Epoch 11/30: 100%|██████████| 83/83 [00:00<00:00, 191.29batch/s]


Epoch 11/30, Average Loss: 0.5382, Accuracy: 80.23%


Epoch 12/30: 100%|██████████| 83/83 [00:00<00:00, 190.80batch/s]


Epoch 12/30, Average Loss: 0.5346, Accuracy: 80.87%


Epoch 13/30: 100%|██████████| 83/83 [00:00<00:00, 228.93batch/s]


Epoch 13/30, Average Loss: 0.5196, Accuracy: 81.19%


Epoch 14/30: 100%|██████████| 83/83 [00:00<00:00, 229.12batch/s]


Epoch 14/30, Average Loss: 0.5124, Accuracy: 81.59%


Epoch 15/30: 100%|██████████| 83/83 [00:00<00:00, 198.78batch/s]


Epoch 15/30, Average Loss: 0.5112, Accuracy: 81.52%


Epoch 16/30: 100%|██████████| 83/83 [00:00<00:00, 119.61batch/s]


Epoch 16/30, Average Loss: 0.5089, Accuracy: 81.81%


Epoch 17/30: 100%|██████████| 83/83 [00:00<00:00, 163.20batch/s]


Epoch 17/30, Average Loss: 0.4990, Accuracy: 82.09%


Epoch 18/30: 100%|██████████| 83/83 [00:00<00:00, 198.38batch/s]


Epoch 18/30, Average Loss: 0.4870, Accuracy: 82.57%


Epoch 19/30: 100%|██████████| 83/83 [00:00<00:00, 194.99batch/s]


Epoch 19/30, Average Loss: 0.4844, Accuracy: 82.80%


Epoch 20/30: 100%|██████████| 83/83 [00:00<00:00, 174.40batch/s]


Epoch 20/30, Average Loss: 0.4834, Accuracy: 82.65%


Epoch 21/30: 100%|██████████| 83/83 [00:00<00:00, 172.77batch/s]


Epoch 21/30, Average Loss: 0.4743, Accuracy: 83.05%


Epoch 22/30: 100%|██████████| 83/83 [00:00<00:00, 145.96batch/s]


Epoch 22/30, Average Loss: 0.4823, Accuracy: 83.16%


Epoch 23/30: 100%|██████████| 83/83 [00:01<00:00, 72.50batch/s] 


Epoch 23/30, Average Loss: 0.4738, Accuracy: 82.93%


Epoch 24/30: 100%|██████████| 83/83 [00:00<00:00, 148.37batch/s]


Epoch 24/30, Average Loss: 0.4684, Accuracy: 83.02%


Epoch 25/30: 100%|██████████| 83/83 [00:00<00:00, 172.20batch/s]


Epoch 25/30, Average Loss: 0.4617, Accuracy: 83.49%


Epoch 26/30: 100%|██████████| 83/83 [00:00<00:00, 154.66batch/s]


Epoch 26/30, Average Loss: 0.4608, Accuracy: 83.40%


Epoch 27/30: 100%|██████████| 83/83 [00:00<00:00, 205.12batch/s]


Epoch 27/30, Average Loss: 0.4627, Accuracy: 83.38%


Epoch 28/30: 100%|██████████| 83/83 [00:00<00:00, 90.58batch/s] 


Epoch 28/30, Average Loss: 0.4658, Accuracy: 83.26%


Epoch 29/30: 100%|██████████| 83/83 [00:00<00:00, 146.98batch/s]

Epoch 29/30, Average Loss: 0.4626, Accuracy: 83.41%
Test Accuracy: 83.91%
49267





In [2]:
# Train the forth model
lr = 0.001
num_epochs = 25
model_4 = model_4()
criterion_4 = nn.CrossEntropyLoss()
optimizer_4 = optim.Adam(model_4.parameters(), lr=lr)
model_4 = train_model(model_4, train_loader_7, criterion_4, optimizer_4, num_epochs)
# Test the forth model
test_model(model_4, test_loader_7)
print(count_parameters(model_4))

Epoch 0/25: 100%|██████████| 83/83 [01:53<00:00,  1.37s/batch]


Epoch 0/25, Average Loss: 1.6001, Accuracy: 56.61%


Epoch 1/25: 100%|██████████| 83/83 [01:47<00:00,  1.30s/batch]


Epoch 1/25, Average Loss: 1.5045, Accuracy: 66.02%


Epoch 2/25: 100%|██████████| 83/83 [01:50<00:00,  1.33s/batch]


Epoch 2/25, Average Loss: 1.4963, Accuracy: 66.83%


Epoch 3/25: 100%|██████████| 83/83 [01:45<00:00,  1.27s/batch]


Epoch 3/25, Average Loss: 1.4955, Accuracy: 66.99%


Epoch 4/25: 100%|██████████| 83/83 [01:51<00:00,  1.34s/batch]


Epoch 4/25, Average Loss: 1.4949, Accuracy: 66.96%


Epoch 5/25: 100%|██████████| 83/83 [01:42<00:00,  1.23s/batch]


Epoch 5/25, Average Loss: 1.4897, Accuracy: 67.71%


Epoch 6/25: 100%|██████████| 83/83 [01:42<00:00,  1.24s/batch]


Epoch 6/25, Average Loss: 1.4868, Accuracy: 67.73%


Epoch 7/25: 100%|██████████| 83/83 [01:44<00:00,  1.26s/batch]


Epoch 7/25, Average Loss: 1.4873, Accuracy: 67.86%


Epoch 8/25: 100%|██████████| 83/83 [01:44<00:00,  1.26s/batch]


Epoch 8/25, Average Loss: 1.4939, Accuracy: 67.27%


Epoch 9/25: 100%|██████████| 83/83 [01:41<00:00,  1.22s/batch]


Epoch 9/25, Average Loss: 1.4844, Accuracy: 68.20%


Epoch 10/25: 100%|██████████| 83/83 [01:40<00:00,  1.21s/batch]


Epoch 10/25, Average Loss: 1.4815, Accuracy: 68.17%


Epoch 11/25: 100%|██████████| 83/83 [01:42<00:00,  1.24s/batch]


Epoch 11/25, Average Loss: 1.4803, Accuracy: 68.40%


Epoch 12/25: 100%|██████████| 83/83 [01:41<00:00,  1.23s/batch]


Epoch 12/25, Average Loss: 1.4806, Accuracy: 68.42%


Epoch 13/25: 100%|██████████| 83/83 [01:41<00:00,  1.22s/batch]


Epoch 13/25, Average Loss: 1.4801, Accuracy: 68.58%


Epoch 14/25: 100%|██████████| 83/83 [01:44<00:00,  1.25s/batch]


Epoch 14/25, Average Loss: 1.4810, Accuracy: 68.60%


Epoch 15/25: 100%|██████████| 83/83 [01:37<00:00,  1.18s/batch]


Epoch 15/25, Average Loss: 1.4785, Accuracy: 68.71%


Epoch 16/25: 100%|██████████| 83/83 [01:41<00:00,  1.22s/batch]


Epoch 16/25, Average Loss: 1.4773, Accuracy: 68.75%


Epoch 17/25: 100%|██████████| 83/83 [01:42<00:00,  1.24s/batch]


Epoch 17/25, Average Loss: 1.4761, Accuracy: 68.87%


Epoch 18/25: 100%|██████████| 83/83 [01:38<00:00,  1.19s/batch]


Epoch 18/25, Average Loss: 1.4743, Accuracy: 68.85%


Epoch 19/25: 100%|██████████| 83/83 [01:42<00:00,  1.24s/batch]


Epoch 19/25, Average Loss: 1.4742, Accuracy: 69.01%


Epoch 20/25: 100%|██████████| 83/83 [01:42<00:00,  1.23s/batch]


Epoch 20/25, Average Loss: 1.4749, Accuracy: 68.91%


Epoch 21/25: 100%|██████████| 83/83 [01:40<00:00,  1.22s/batch]


Epoch 21/25, Average Loss: 1.4745, Accuracy: 69.03%


Epoch 22/25: 100%|██████████| 83/83 [01:46<00:00,  1.28s/batch]


Epoch 22/25, Average Loss: 1.4754, Accuracy: 69.01%


Epoch 23/25: 100%|██████████| 83/83 [01:46<00:00,  1.28s/batch]


Epoch 23/25, Average Loss: 1.4730, Accuracy: 69.10%


Epoch 24/25: 100%|██████████| 83/83 [01:37<00:00,  1.18s/batch]


Epoch 24/25, Average Loss: 1.4735, Accuracy: 69.23%
Test Accuracy: 68.90%
411911
