In [1]:
import numpy as np
import seaborn as sns
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from tqdm import tqdm
from twoChannel3DNet import twoChannel3DNet
# import twoChannel3DNet
from util import add_color, colorize, colorize_gaussian, calculate_correct_loss
from colorMNist import colorMNist
import random
import colorsys
import pickle
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

In [None]:
# Re-import the network definition after editing twoChannel3DNet.py, without restarting the kernel.
# The top import binds `twoChannel3DNet` to the *class*, and importlib.reload only accepts a
# module, so look the module up by name, reload it, then rebind the class name.
# Note: models created before this cell keep the old class definition; re-create them afterwards.
import importlib

_twoChannel3DNet_module = importlib.reload(importlib.import_module("twoChannel3DNet"))
twoChannel3DNet = _twoChannel3DNet_module.twoChannel3DNet

In [62]:
# Load the pre-built train/val/test splits from the pickle file.
# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
with open("custom_datasets/color_uniform.pkl", "rb") as dataset_file:
    cmnist_train, cmnist_val, cmnist_test = pickle.load(dataset_file)
print(len(cmnist_train), len(cmnist_val), len(cmnist_train) + len(cmnist_val))

# Wrap each split in a colorMNist dataset
train_dataset = colorMNist(cmnist_train)
val_dataset = colorMNist(cmnist_val)
test_dataset = colorMNist(cmnist_test)

# Dataloaders. Only the training set is shuffled: with a fixed validation order the
# composition of the partial last batch (and so the mean-of-batch-means val loss)
# is the same every epoch and every run.
BATCH_SIZE = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                               shuffle=True, num_workers=0)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE,
                                             shuffle=False, num_workers=0)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                              shuffle=False, num_workers=0)

4000 1000 5000


In [63]:
# Layer specification handed to twoChannel3DNet ("M" entries presumably mark pooling
# layers -- TODO confirm against twoChannel3DNet), plus the second constructor argument 16.
model_layers = [8, 8, "M", 16, "M"]
model = twoChannel3DNet(model_layers, 16)

# `lfile` names the checkpoint loaded from model_saves/TwoChannel/;
# `sfile` is the name later cells use for the metrics file and the saved weights.
lfile = "2CGaussian3D_36"
sfile = "C-2CGaussian3D_36"
model.load_state_dict(torch.load(f"model_saves/TwoChannel/{lfile}.pth"))
model.cuda()

# Cross-entropy loss, SGD with momentum over all model parameters
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

In [64]:
# Freeze every sub-module of the model's first child so fine-tuning leaves those
# weights untouched; the other children are not modified here. The optimizer built
# earlier still holds these parameters, but they receive no gradients.
children = list(model.children())
for layer in children[0]:
    layer.requires_grad_(False)

In [65]:
# Number of epochs to train
epochs = 10

# Per-epoch training/validation accuracies and losses (read again by the cells below)
train_accuracies = []
train_losses = []
val_accuracies = []
val_losses = []


def run_epoch(model, dataloader, loss_fn, optimizer=None):
    """Run one pass over `dataloader` and return (accuracy, mean batch loss).

    With an `optimizer` the model is put in training mode and updated after every
    batch; without one it is put in evaluation mode and gradients are disabled.

    Args:
        model: network passed to `calculate_correct_loss`.
        dataloader: iterable of (images, labels) batches with a `.dataset` attribute.
        loss_fn: loss function passed to `calculate_correct_loss`.
        optimizer: optimizer to step after each batch, or None for evaluation.

    Returns:
        (total correct / len(dataloader.dataset),
         average of the per-batch `loss.item()` values).
    """
    training = optimizer is not None
    # NOTE(review): train mode also applies to the frozen first child; if it contains
    # BatchNorm/Dropout, their behavior still changes during fine-tuning -- confirm intended.
    model.train(training)
    total_correct = 0
    batch_losses = []
    with torch.set_grad_enabled(training):
        for images, labels in tqdm(dataloader):
            if training:
                optimizer.zero_grad()
            # Number correct and loss for this batch
            correct, loss = calculate_correct_loss(model, loss_fn, images, labels, model_type=3)
            if training:
                loss.backward()
                optimizer.step()
            total_correct += correct
            batch_losses.append(loss.item())
    # Mean of per-batch losses: a partial last batch is weighted like a full one.
    return total_correct / len(dataloader.dataset), sum(batch_losses) / len(batch_losses)


for epoch in range(epochs):
    print("Epoch", epoch + 1, "/", epochs)

    train_acc, train_loss = run_epoch(model, train_dataloader, loss_fn, optimizer)
    train_accuracies.append(train_acc)
    train_losses.append(train_loss)

    val_acc, val_loss = run_epoch(model, val_dataloader, loss_fn)
    val_accuracies.append(val_acc)
    val_losses.append(val_loss)

# Print accuracies and losses per epoch
for i in range(epochs):
    print("Epoch", i + 1)
    print("Train acc and loss\t", train_accuracies[i], "\t", train_losses[i])
    print("Val acc and loss\t", val_accuracies[i], "\t", val_losses[i])

Epoch 1 / 10


100%|██████████| 125/125 [00:02<00:00, 58.50it/s]
100%|██████████| 32/32 [00:00<00:00, 260.87it/s]


Epoch 2 / 10


100%|██████████| 125/125 [00:02<00:00, 58.47it/s]
100%|██████████| 32/32 [00:00<00:00, 255.70it/s]


Epoch 3 / 10


100%|██████████| 125/125 [00:02<00:00, 58.66it/s]
100%|██████████| 32/32 [00:00<00:00, 271.92it/s]


Epoch 4 / 10


100%|██████████| 125/125 [00:02<00:00, 58.79it/s]
100%|██████████| 32/32 [00:00<00:00, 263.00it/s]


Epoch 5 / 10


100%|██████████| 125/125 [00:02<00:00, 58.57it/s]
100%|██████████| 32/32 [00:00<00:00, 261.85it/s]


Epoch 6 / 10


100%|██████████| 125/125 [00:02<00:00, 59.24it/s]
100%|██████████| 32/32 [00:00<00:00, 244.87it/s]


Epoch 7 / 10


100%|██████████| 125/125 [00:02<00:00, 58.85it/s]
100%|██████████| 32/32 [00:00<00:00, 261.02it/s]


Epoch 8 / 10


100%|██████████| 125/125 [00:02<00:00, 58.47it/s]
100%|██████████| 32/32 [00:00<00:00, 241.25it/s]


Epoch 9 / 10


100%|██████████| 125/125 [00:02<00:00, 58.67it/s]
100%|██████████| 32/32 [00:00<00:00, 265.66it/s]


Epoch 10 / 10


100%|██████████| 125/125 [00:02<00:00, 58.40it/s]
100%|██████████| 32/32 [00:00<00:00, 320.84it/s]

Epoch 1
Train acc and loss	 0.85125 	 0.7409883699342609
Val acc and loss	 0.999 	 0.01124404979054816
Epoch 2
Train acc and loss	 0.99225 	 0.03481468350440264
Val acc and loss	 1.0 	 0.0051407317023404175
Epoch 3
Train acc and loss	 0.99625 	 0.016781851520761847
Val acc and loss	 1.0 	 0.0035313307998876553
Epoch 4
Train acc and loss	 0.998 	 0.009544096610974521
Val acc and loss	 0.999 	 0.0026017841641987616
Epoch 5
Train acc and loss	 0.99925 	 0.006599170670378953
Val acc and loss	 1.0 	 0.0023724187026346044
Epoch 6
Train acc and loss	 0.99925 	 0.005113052575150504
Val acc and loss	 1.0 	 0.0018756061593023787
Epoch 7
Train acc and loss	 0.9995 	 0.004484650751342997
Val acc and loss	 0.999 	 0.003527627369521724
Epoch 8
Train acc and loss	 0.99975 	 0.003290768977603875
Val acc and loss	 1.0 	 0.001835124978128988
Epoch 9
Train acc and loss	 0.99925 	 0.003290656013879925
Val acc and loss	 1.0 	 0.0014113304962961593
Epoch 10
Train acc and loss	 1.0 	 0.002994643862475641
Val




In [66]:
# Persist the per-epoch metrics in the same layout as the printed summary.
with open(f"trainvalAccs/{sfile}.txt", "w") as metrics_file:
    for epoch_idx in range(epochs):
        metrics_file.write(f"Epoch {epoch_idx + 1!s}\n")
        metrics_file.write(
            f"Train acc and loss\t{train_accuracies[epoch_idx]!s}\t{train_losses[epoch_idx]!s}\n"
        )
        metrics_file.write(
            f"Val acc and loss\t{val_accuracies[epoch_idx]!s}\t{val_losses[epoch_idx]!s}\n"
        )

In [67]:
# Write the fine-tuned weights to model_saves/ColorExperiment/<sfile>.pth
torch.save(model.state_dict(), f"model_saves/ColorExperiment/{sfile}.pth")

In [68]:
# Rebuild the architecture and load the weights saved above, so the evaluation
# below runs on the checkpoint written to disk.
model_layers = [8, 8, "M", 16, "M"]
model = twoChannel3DNet(model_layers, 16)
model.load_state_dict(torch.load(f"model_saves/ColorExperiment/{sfile}.pth"))
# Move the model to the GPU; the trailing ';' suppresses the module repr in the cell output.
model.cuda();

In [69]:
# Testing: overall accuracy, per-class correct counts, and confusion counts for errors.
num_classes = 10  # assumes labels lie in range(num_classes)

# right_dict[c]: test images of class c that were predicted correctly.
# wrong_dict[t][p]: test images with true label t predicted as p (only filled for p != t).
right_dict = {c: 0 for c in range(num_classes)}
wrong_dict = {t: {p: 0 for p in range(num_classes)} for t in range(num_classes)}

test_correct = 0
test_total = 0

model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_dataloader):
        # Add the extra dimension the network expects (dim 1) and run on the GPU.
        preds = model(images.unsqueeze(1).cuda())

        # Top prediction per image, moved back to the CPU for bookkeeping.
        _, top_preds = torch.max(preds, 1)
        top_preds = top_preds.cpu()

        for true_label, pred_label in zip(labels.tolist(), top_preds.tolist()):
            if pred_label == true_label:
                right_dict[pred_label] += 1
                test_correct += 1
            else:
                wrong_dict[true_label][pred_label] += 1

        test_total += len(labels)

# Guard against an empty test set instead of raising ZeroDivisionError.
test_accuracy = test_correct / test_total if test_total else float("nan")
print("Correct", test_correct, "/", test_total, "Accuracy:", test_accuracy)
    

100%|██████████| 313/313 [00:01<00:00, 260.66it/s]

Correct 10000 / 10000 Accuracy: 1.0





In [None]:
# One "label : count" line per class with its number of correct test predictions
for label, n_correct in right_dict.items():
    print(label, ":", n_correct)

In [None]:
# Nested-list form of wrong_dict for plotting: hm[t][p] = wrong_dict[t][p],
# rows/columns in the dicts' key order.
hm = [
    [wrong_dict[true_label][pred_label] for pred_label in wrong_dict[true_label]]
    for true_label in wrong_dict
]
print(hm)

In [None]:
# Heatmap of misclassifications only (correct predictions are not recorded in wrong_dict,
# so the diagonal is always 0): rows are true labels, columns are predicted labels.
ax = sns.heatmap(hm, cmap="Blues")
ax.set(xlabel="Predicted label", ylabel="True label", title="Test misclassifications");

In [None]:
# Scratch check: appending a single all-zero channel to a 3-channel block of ones
# along dim 2 gives a 4-channel tensor.
ones = torch.ones([32, 1, 3, 28, 28])
zeros = torch.zeros([32, 1, 28, 28]).unsqueeze(dim=2)  # -> [32, 1, 1, 28, 28]
res = torch.cat((ones, zeros), dim=2)                   # -> [32, 1, 4, 28, 28]
print(res.shape)