In [1]:
import torch 
import torch.optim as optim
from torch.utils.data import DataLoader 
import numpy as np
import matplotlib.pyplot as plt

from AE_functions import *
from make_dataset import *

%matplotlib qt

import warnings
warnings.filterwarnings("ignore")

# First try, simple NN

Training on one image with Adam optimizer (lr = 1e-5) and AdamW optimizer (lr = 1e-3)

In [2]:
#Define paramters
parameters_dict = {
    'epochs': 500,
    'learning_rate': 1e-3, #NOTE - change here
    'batch_size': 1, 
    'weight_decay': 5e-4 
}

## Unpack parameters
num_epochs = parameters_dict['epochs']
lr = parameters_dict['learning_rate']
batch_size = parameters_dict['batch_size']
wd = parameters_dict['weight_decay']


## Loading data

img_dir_training = "C:/Users/julie/Bachelor_data/crops_training_prep/img"
heatmap_dir_training = "C:/Users/julie/Bachelor_data/crops_training_prep/heatmaps"
msk_dir_training = "C:/Users/julie/Bachelor_data/crops_training_prep/msk"


VerSe_train = LoadData(img_dir=img_dir_training, msk_dir = msk_dir_training, distfield_dir=heatmap_dir_training)
train_loader = DataLoader(VerSe_train, batch_size=batch_size, shuffle=True, num_workers=0)
    # 39 elements (images) in train_loader
    # Each element is a tuple of 3 elements: (img, heatmap, msk)
    # img: torch.Size([2, 128, 128, 96])

input_train, y, z = train_loader.dataset[10]
# plt.imshow(input_train[0][64, :, :], cmap='gray')
# plt.title('Original')
# plt.show()


## Define model
# For simple AE
model = AE([128*96, 512, 256, 128])
# model = AE([128*96, 512, 256, 128, 64])
# model = AE([128*96, 2*512, 2*256, 2*128, 2*64])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
model.to(device)
print(model)

# optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)


o_loss = []
train_loss = []
val_loss = []

recon_img = []
recon_img.append(input_train[0][64,:,:].numpy())
recon_img_name = []
recon_img_name.append('Original')

## Train model
def train(model, optimizer, epochs, device):
    model.train()
    step = -1

    for epoch in range(epochs):

        overall_loss = 0

        x = input_train[0][64,:,:].unsqueeze(dim=0)
        x = x.view(1, -1)
        x = x.to(device)

        x_reconstructed = model(x)

        #-- Loss function
        loss = loss_function(x_reconstructed, x)

        overall_loss += loss.item()

        optimizer.zero_grad()
        loss. backward()
        optimizer.step()

        # Update step
        step+=1

        # Do evaluation every 50 epoch
        if step%25 == 0:
            print()
            print("EVALUATION!")
            model.eval() #Set to evaluation

            #Training evaluation
            val_loss_eval = []
            with torch.no_grad():
                inputs = input_train[0][64,:,:].unsqueeze(dim=0)
                inputs = inputs.view(1, -1)

                inputs = inputs.to(device)
                inputs_reconstructed = model(inputs)
                
                #-- Loss function
                v_loss = loss_function(inputs_reconstructed, inputs)

                #-- Save image
                if step%50 == 0:
                    recon_img.append(inputs_reconstructed.detach().cpu().numpy().reshape(128, 96))
                    recon_img_name.append('Reconstructed image, '+str(step)+' epochs')


                # Save loss
                val_loss_eval.append(v_loss.item())
            avg_loss_val = np.mean(val_loss_eval)
            print("Validation loss: "+str(avg_loss_val))
            val_loss.append(avg_loss_val)


        o_loss.append(overall_loss)  
    

train(model, optimizer, num_epochs, device=device)

AE(
  (encoder): Sequential(
    (0): Linear(in_features=12288, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=12288, bias=True)
    (5): Sigmoid()
  )
)

EVALUATION!
Validation loss: 0.968811571598053

EVALUATION!
Validation loss: 0.2670501470565796

EVALUATION!
Validation loss: 0.26708781719207764

EVALUATION!
Validation loss: 0.26708781719207764

EVALUATION!
Validation loss: 0.26708781719207764

EVALUATION!
Validation loss: 0.26708781719207764

EVALUATION!
Validation loss: 0.26708781719207764

EVALUATION!
Validation loss: 0.26708781719207764

EVALUATION!
Validation loss: 0.26708781719207764

EVALUATION!
Validation loss: 

In [1]:
## Plotting losses from training

print(f"Converged towards {np.mean(o_loss[350:])}")

fig, ax = plt.subplots()
ax.set_title(f'Model Loss, batch_size=1, lr={lr}, wd=0.0005')
ax.set_xlabel('Epoch')
ax.set_ylabel('Avg. loss')
ax.set_xticks(np.arange(0, len(o_loss), step= 25))

ax.plot(list(range(1, len(o_loss)+1, 1)), o_loss, label='Training loss', color='b')  # Update the plot with the current loss
ax.plot(list(range(25, len(o_loss)+1, 25)), val_loss, label='Validation loss', color='r')

ax.legend()
plt.show()

NameError: name 'np' is not defined

In [4]:
## Plotting reconstructed images

fig, ax  = plt.subplots(2, 3, figsize=(12, 15))
ii = [0, 1, 2, 3, 6, 10]

for i in range(1,6):
    ax[i//3, i%3].imshow(recon_img[ii[i]], cmap='gray')
    ax[i//3, i%3].set_title(recon_img_name[ii[i]])


ax[0,0].imshow(input_train[0][64, :, :], cmap='gray')
ax[0,0].set_title('Original Image')

fig.show()


In [3]:
## Make reconstruction

model.eval()
input_train, y, z = train_loader.dataset[10]
org_img = input_train[0][64,:,:].unsqueeze(dim=0)
x = org_img.view(1, -1)

x = x.to(device)
x_reconstructed = model(x) 
print(f'loss={loss_function(x_reconstructed, x)}')

x_reconstructed = x_reconstructed.detach().cpu().numpy().reshape(128, 96)

diff_img = org_img.squeeze() - x_reconstructed

print(torch.min(diff_img), torch.max(diff_img))


## Plotting the difference from outlier input to reconstructions
fig, ax = plt.subplots(1, 3, figsize=(12, 4))

ax[0].imshow(org_img.squeeze(), cmap='gray')
ax[0].set_title('Original Image')

ax[1].imshow(x_reconstructed, cmap='gray')
ax[1].set_title('Reconstructed Image')

diff_plot = ax[2].imshow(org_img.squeeze() - x_reconstructed, vmin=-2, vmax=2, cmap='bwr')
ax[2].set_title('Difference')
fig.colorbar(diff_plot, ax=ax[2])


plt.show()

loss=0.26708781719207764
tensor(-0.9077) tensor(0.6194)
