In [1]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import tqdm

from utils import calculate_gradient_and_update, cnn_train_step, cnn_val_step, save_checkpoint, mmn_train_step, mmn_val_step
from dataloader import CustomDataloader
from networks import CNNRegression, MultiModalNetwork

#### Select the compute device (use the GPU automatically if one is available)

In [2]:
# Ask the CUDA runtime once whether a GPU is visible, and echo the answer.
cuda_available = bool(torch.cuda.is_available())
print(f"Is CUDA available?  {cuda_available}")

Is CUDA available?  True


In [3]:
if (torch.cuda.is_available()):
    device = torch.device("cuda")
    print('Using GPU:', torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print('GPU not available, using CPU')
    

Using GPU: NVIDIA GeForce GTX 1660 Ti


### Model 1 - Linear Regression ###

In [4]:
# Load the Model 1 (linear regression) train / test / validation splits from tensor_collection.
# Forward slashes resolve correctly on both Windows and POSIX; a '\\' separator
# only works on Windows (elsewhere it becomes part of the file name).
cm_x_train = torch.load('tensor_collection/cm_x_train.pt')
cm_y_train = torch.load('tensor_collection/cm_y_train.pt')
cm_x_test = torch.load('tensor_collection/cm_x_test.pt')
cm_y_test = torch.load('tensor_collection/cm_y_test.pt')
cm_x_val = torch.load('tensor_collection/cm_x_val.pt')
cm_y_val = torch.load('tensor_collection/cm_y_val.pt')



#### Hyperparameters setup

In [5]:
# Linear-regression hyperparameters and parameter initialisation.
num_features = cm_x_train.shape[1]

train_loss_history = []
val_loss_history = []
num_epochs = 100
alpha = 0.1  # scale of each update passed to calculate_gradient_and_update (0 disables updates)

# Initialize theta to random values in [-2, 2). A dedicated, seeded generator makes the
# starting point -- and therefore the whole training run -- reproducible on
# Restart & Run All, without touching NumPy's global random state.
theta_rng = np.random.default_rng(42)
theta = theta_rng.uniform(-2, 2, num_features)


#### Model 1 - Training Loop 

In [6]:

# Linear-model training loop. Every epoch makes one calculate_gradient_and_update
# call on the training split (which returns the new theta), then reuses the same
# helper on the validation split only to measure the loss.
for step in range(num_epochs):
    train_loss, theta = calculate_gradient_and_update(cm_x_train, cm_y_train, theta, alpha)
    # alpha=0 prevents updates, and the returned parameters are discarded anyway,
    # so theta is left exactly as the training step produced it.
    val_loss, _ = calculate_gradient_and_update(cm_x_val, cm_y_val, theta, 0)
    train_loss_history.append(train_loss)
    val_loss_history.append(val_loss)


In [7]:
# Save theta and the per-epoch loss histories as float32 tensors for the analysis notebook.
theta_torch = torch.tensor(theta, dtype=torch.float32)
train_loss_history_torch = torch.tensor(train_loss_history, dtype=torch.float32)
val_loss_history_torch = torch.tensor(val_loss_history, dtype=torch.float32)

# Forward slashes work on Windows and POSIX alike; '\\' only works on Windows.
torch.save(theta_torch, 'tensor_collection/lr_theta.pt')
torch.save(train_loss_history_torch, 'tensor_collection/lr_train_loss.pt')
torch.save(val_loss_history_torch, 'tensor_collection/lr_val_loss.pt')


### Model 2 - Training CNN ###

#### Load the Train and Val sets for Model 2

In [8]:
# Model 2/3 inputs: feature tables are read from CSV, age targets from .pt files.
# Forward slashes work on Windows and POSIX; a '\\' separator only works on Windows.
directory = 'tensor_collection/'

#load train and val csv files
nn_x_train = pd.read_csv(directory + 'nn_x_train.csv')
nn_x_val = pd.read_csv(directory + 'nn_x_val.csv')
# NOTE: despite the `_tensor` suffix, the recorded output of this cell shows that
# torch.load returns a numpy.ndarray here.
# NOTE(review): newer torch releases default torch.load to weights_only=True, which may
# refuse pickled ndarrays -- confirm against the installed torch version.
nn_y_train_tensor = torch.load(directory + 'nn_y_train.pt')
nn_y_val_tensor = torch.load(directory + 'nn_y_val.pt')

# The dataloaders pair each feature row with an age target, so the lengths must agree.
# Catch a mismatched split here rather than deep inside training.
assert len(nn_x_train) == len(nn_y_train_tensor), "nn_x_train / nn_y_train length mismatch"
assert len(nn_x_val) == len(nn_y_val_tensor), "nn_x_val / nn_y_val length mismatch"

print(type(nn_x_train))
print(type(nn_y_train_tensor))


<class 'pandas.core.frame.DataFrame'>
<class 'numpy.ndarray'>


#### Model 2 - Hyperparameters

In [9]:
# CNN hyperparameters. Seed torch before building the network so the weight
# initialisation is reproducible across Restart & Run All (assumes CNNRegression
# draws its initial weights from torch's default RNG -- confirm in networks.py).
torch.manual_seed(42)
model = CNNRegression().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 10
loss_fn = nn.MSELoss()
batch_sz = 20

#### Instantiating Train and Val Dataloaders

In [10]:
# Batch iterators over the CSV feature frames and their age targets.
# randomize=True for training; validation keeps randomize=False.
train_dataloader = CustomDataloader(
    dataframe=nn_x_train,
    age=nn_y_train_tensor,
    batch_size=batch_sz,
    randomize=True,
)
val_dataloader = CustomDataloader(
    dataframe=nn_x_val,
    age=nn_y_val_tensor,
    batch_size=batch_sz,
    randomize=False,
)

#### Model 2 - Training Loop

In [11]:
# CNN training loop. Each epoch calls cnn_train_step on the training loader,
# cnn_val_step on the validation loader, then writes one checkpoint file.
# save_checkpoint targets model_checkpoints/, so make sure that directory exists
# (e.g. on a fresh checkout) before the first epoch finishes.
os.makedirs('model_checkpoints', exist_ok=True)

train_losses = []
val_losses = []
for epoch in tqdm.tqdm(range(epochs), desc='CNN epochs'):
    # Training step
    train_loss = cnn_train_step(model, train_dataloader, loss_fn, optimizer, device)
    train_losses.append(train_loss)

    # Validation step
    val_loss = cnn_val_step(model, val_dataloader, loss_fn, device)
    val_losses.append(val_loss)

    # One checkpoint per epoch, so any epoch can be restored later.
    save_checkpoint(model, optimizer, epoch, f'model_checkpoints/nn_checkpoint_epoch_{epoch}.pth')

  0%|          | 0/10 [00:00<?, ?it/s]

Fetch: 1 :  True
Step_Train: 1 :  True
Fetch: 2 :  True
Step_Train: 2 :  True
Fetch: 3 :  True
Step_Train: 3 :  True
Fetch: 4 :  True
Step_Train: 4 :  True
Fetch: 5 :  True
Step_Train: 5 :  True
Fetch: 6 :  True
Step_Train: 6 :  True
Fetch: 7 :  True
Step_Train: 7 :  True
Fetch: 8 :  True
Step_Train: 8 :  True
Fetch: 9 :  True
Step_Train: 9 :  True
Fetch: 10 :  True
Step_Train: 10 :  True
Fetch: 11 :  True
Step_Train: 11 :  True
Fetch: 12 :  True
Step_Train: 12 :  True
Fetch: 13 :  True
Step_Train: 13 :  True
Fetch: 14 :  True
Step_Train: 14 :  True
Fetch: 15 :  True
Step_Train: 15 :  True
Fetch: 16 :  True
Step_Train: 16 :  True
Fetch: 17 :  True
Step_Train: 17 :  True
Fetch: 18 :  True
Step_Train: 18 :  True
Fetch: 19 :  True
Step_Train: 19 :  True
Fetch: 20 :  True
Step_Train: 20 :  True
Fetch: 21 :  True
Step_Train: 21 :  True
Fetch: 22 :  True
Step_Train: 22 :  True
Fetch: 23 :  True
Step_Train: 23 :  True
Fetch: 24 :  True
Step_Train: 24 :  True
Fetch: 25 :  True
Step_Train: 25 :

 10%|█         | 1/10 [00:01<00:10,  1.16s/it]

Fetch: 37 :  True
Step_Train: 37 :  True
Fetch: 38 :  True
Step_Train: 38 :  True
Fetch: 39 :  True
Step_Train: 39 :  True
Fetch: 40 :  True
Step_Train: 40 :  True
Fetch: 1 :  True
Step_Val: 1 :  True
Fetch: 2 :  True
Step_Val: 2 :  True
Fetch: 3 :  True
Step_Val: 3 :  True
Fetch: 4 :  True
Step_Val: 4 :  True
Fetch: 5 :  True
Step_Val: 5 :  True
Fetch: 1 :  True
Step_Train: 1 :  True
Fetch: 2 :  True
Step_Train: 2 :  True
Fetch: 3 :  True
Step_Train: 3 :  True
Fetch: 4 :  True
Step_Train: 4 :  True
Fetch: 5 :  True
Step_Train: 5 :  True
Fetch: 6 :  True
Step_Train: 6 :  True
Fetch: 7 :  True
Step_Train: 7 :  True
Fetch: 8 :  True
Step_Train: 8 :  True
Fetch: 9 :  True
Step_Train: 9 :  True
Fetch: 10 :  True
Step_Train: 10 :  True
Fetch: 11 :  True
Step_Train: 11 :  True
Fetch: 12 :  True
Step_Train: 12 :  True
Fetch: 13 :  True
Step_Train: 13 :  True
Fetch: 14 :  True
Step_Train: 14 :  True
Fetch: 15 :  True
Step_Train: 15 :  True
Fetch: 16 :  True
Step_Train: 16 :  True
Fetch: 17 :  

 20%|██        | 2/10 [00:02<00:07,  1.00it/s]

Fetch: 37 :  True
Step_Train: 37 :  True
Fetch: 38 :  True
Step_Train: 38 :  True
Fetch: 39 :  True
Step_Train: 39 :  True
Fetch: 40 :  True
Step_Train: 40 :  True
Fetch: 1 :  True
Step_Val: 1 :  True
Fetch: 2 :  True
Step_Val: 2 :  True
Fetch: 3 :  True
Step_Val: 3 :  True
Fetch: 4 :  True
Step_Val: 4 :  True
Fetch: 5 :  True
Step_Val: 5 :  True
Fetch: 1 :  True
Step_Train: 1 :  True
Fetch: 2 :  True
Step_Train: 2 :  True
Fetch: 3 :  True
Step_Train: 3 :  True
Fetch: 4 :  True
Step_Train: 4 :  True
Fetch: 5 :  True
Step_Train: 5 :  True
Fetch: 6 :  True
Step_Train: 6 :  True
Fetch: 7 :  True
Step_Train: 7 :  True
Fetch: 8 :  True
Step_Train: 8 :  True
Fetch: 9 :  True
Step_Train: 9 :  True
Fetch: 10 :  True
Step_Train: 10 :  True
Fetch: 11 :  True
Step_Train: 11 :  True
Fetch: 12 :  True
Step_Train: 12 :  True
Fetch: 13 :  True
Step_Train: 13 :  True
Fetch: 14 :  True
Step_Train: 14 :  True
Fetch: 15 :  True
Step_Train: 15 :  True
Fetch: 16 :  True
Step_Train: 16 :  True
Fetch: 17 :  

 30%|███       | 3/10 [00:02<00:06,  1.08it/s]

Fetch: 39 :  True
Step_Train: 39 :  True
Fetch: 40 :  True
Step_Train: 40 :  True
Fetch: 1 :  True
Step_Val: 1 :  True
Fetch: 2 :  True
Step_Val: 2 :  True
Fetch: 3 :  True
Step_Val: 3 :  True
Fetch: 4 :  True
Step_Val: 4 :  True
Fetch: 5 :  True
Step_Val: 5 :  True
Fetch: 1 :  True
Step_Train: 1 :  True
Fetch: 2 :  True
Step_Train: 2 :  True
Fetch: 3 :  True
Step_Train: 3 :  True
Fetch: 4 :  True
Step_Train: 4 :  True
Fetch: 5 :  True
Step_Train: 5 :  True
Fetch: 6 :  True
Step_Train: 6 :  True
Fetch: 7 :  True
Step_Train: 7 :  True
Fetch: 8 :  True
Step_Train: 8 :  True
Fetch: 9 :  True
Step_Train: 9 :  True
Fetch: 10 :  True
Step_Train: 10 :  True
Fetch: 11 :  True
Step_Train: 11 :  True
Fetch: 12 :  True
Step_Train: 12 :  True
Fetch: 13 :  True
Step_Train: 13 :  True
Fetch: 14 :  True
Step_Train: 14 :  True
Fetch: 15 :  True
Step_Train: 15 :  True
Fetch: 16 :  True
Step_Train: 16 :  True
Fetch: 17 :  True
Step_Train: 17 :  True
Fetch: 18 :  True
Step_Train: 18 :  True
Fetch: 19 :  

 40%|████      | 4/10 [00:03<00:05,  1.13it/s]

Fetch: 39 :  True
Step_Train: 39 :  True
Fetch: 40 :  True
Step_Train: 40 :  True
Fetch: 1 :  True
Step_Val: 1 :  True
Fetch: 2 :  True
Step_Val: 2 :  True
Fetch: 3 :  True
Step_Val: 3 :  True
Fetch: 4 :  True
Step_Val: 4 :  True
Fetch: 5 :  True
Step_Val: 5 :  True
Fetch: 1 :  True
Step_Train: 1 :  True
Fetch: 2 :  True
Step_Train: 2 :  True
Fetch: 3 :  True
Step_Train: 3 :  True
Fetch: 4 :  True
Step_Train: 4 :  True
Fetch: 5 :  True
Step_Train: 5 :  True
Fetch: 6 :  True
Step_Train: 6 :  True
Fetch: 7 :  True
Step_Train: 7 :  True
Fetch: 8 :  True
Step_Train: 8 :  True
Fetch: 9 :  True
Step_Train: 9 :  True
Fetch: 10 :  True
Step_Train: 10 :  True
Fetch: 11 :  True
Step_Train: 11 :  True
Fetch: 12 :  True
Step_Train: 12 :  True
Fetch: 13 :  True
Step_Train: 13 :  True
Fetch: 14 :  True
Step_Train: 14 :  True
Fetch: 15 :  True
Step_Train: 15 :  True
Fetch: 16 :  True
Step_Train: 16 :  True
Fetch: 17 :  True
Step_Train: 17 :  True
Fetch: 18 :  True
Step_Train: 18 :  True
Fetch: 19 :  

 50%|█████     | 5/10 [00:04<00:04,  1.15it/s]

Fetch: 38 :  True
Step_Train: 38 :  True
Fetch: 39 :  True
Step_Train: 39 :  True
Fetch: 40 :  True
Step_Train: 40 :  True
Fetch: 1 :  True
Step_Val: 1 :  True
Fetch: 2 :  True
Step_Val: 2 :  True
Fetch: 3 :  True
Step_Val: 3 :  True
Fetch: 4 :  True
Step_Val: 4 :  True
Fetch: 5 :  True
Step_Val: 5 :  True
Fetch: 1 :  True
Step_Train: 1 :  True
Fetch: 2 :  True
Step_Train: 2 :  True
Fetch: 3 :  True
Step_Train: 3 :  True
Fetch: 4 :  True
Step_Train: 4 :  True
Fetch: 5 :  True
Step_Train: 5 :  True
Fetch: 6 :  True
Step_Train: 6 :  True
Fetch: 7 :  True
Step_Train: 7 :  True
Fetch: 8 :  True
Step_Train: 8 :  True
Fetch: 9 :  True
Step_Train: 9 :  True
Fetch: 10 :  True
Step_Train: 10 :  True
Fetch: 11 :  True
Step_Train: 11 :  True
Fetch: 12 :  True
Step_Train: 12 :  True
Fetch: 13 :  True
Step_Train: 13 :  True
Fetch: 14 :  True
Step_Train: 14 :  True
Fetch: 15 :  True
Step_Train: 15 :  True
Fetch: 16 :  True
Step_Train: 16 :  True
Fetch: 17 :  True
Step_Train: 17 :  True
Fetch: 18 :  

 60%|██████    | 6/10 [00:05<00:03,  1.17it/s]

Fetch: 39 :  True
Step_Train: 39 :  True
Fetch: 40 :  True
Step_Train: 40 :  True
Fetch: 1 :  True
Step_Val: 1 :  True
Fetch: 2 :  True
Step_Val: 2 :  True
Fetch: 3 :  True
Step_Val: 3 :  True
Fetch: 4 :  True
Step_Val: 4 :  True
Fetch: 5 :  True
Step_Val: 5 :  True
Fetch: 1 :  True
Step_Train: 1 :  True
Fetch: 2 :  True
Step_Train: 2 :  True
Fetch: 3 :  True
Step_Train: 3 :  True
Fetch: 4 :  True
Step_Train: 4 :  True
Fetch: 5 :  True
Step_Train: 5 :  True
Fetch: 6 :  True
Step_Train: 6 :  True
Fetch: 7 :  True
Step_Train: 7 :  True
Fetch: 8 :  True
Step_Train: 8 :  True
Fetch: 9 :  True
Step_Train: 9 :  True
Fetch: 10 :  True
Step_Train: 10 :  True
Fetch: 11 :  True
Step_Train: 11 :  True
Fetch: 12 :  True
Step_Train: 12 :  True
Fetch: 13 :  True
Step_Train: 13 :  True
Fetch: 14 :  True
Step_Train: 14 :  True
Fetch: 15 :  True
Step_Train: 15 :  True
Fetch: 16 :  True
Step_Train: 16 :  True
Fetch: 17 :  True
Step_Train: 17 :  True
Fetch: 18 :  True
Step_Train: 18 :  True
Fetch: 19 :  

 70%|███████   | 7/10 [00:06<00:02,  1.17it/s]

Fetch: 39 :  True
Step_Train: 39 :  True
Fetch: 40 :  True
Step_Train: 40 :  True
Fetch: 1 :  True
Step_Val: 1 :  True
Fetch: 2 :  True
Step_Val: 2 :  True
Fetch: 3 :  True
Step_Val: 3 :  True
Fetch: 4 :  True
Step_Val: 4 :  True
Fetch: 5 :  True
Step_Val: 5 :  True
Fetch: 1 :  True
Step_Train: 1 :  True
Fetch: 2 :  True
Step_Train: 2 :  True
Fetch: 3 :  True
Step_Train: 3 :  True
Fetch: 4 :  True
Step_Train: 4 :  True
Fetch: 5 :  True
Step_Train: 5 :  True
Fetch: 6 :  True
Step_Train: 6 :  True
Fetch: 7 :  True
Step_Train: 7 :  True
Fetch: 8 :  True
Step_Train: 8 :  True
Fetch: 9 :  True
Step_Train: 9 :  True
Fetch: 10 :  True
Step_Train: 10 :  True
Fetch: 11 :  True
Step_Train: 11 :  True
Fetch: 12 :  True
Step_Train: 12 :  True
Fetch: 13 :  True
Step_Train: 13 :  True
Fetch: 14 :  True
Step_Train: 14 :  True
Fetch: 15 :  True
Step_Train: 15 :  True
Fetch: 16 :  True
Step_Train: 16 :  True
Fetch: 17 :  True
Step_Train: 17 :  True
Fetch: 18 :  True
Step_Train: 18 :  True
Fetch: 19 :  

 80%|████████  | 8/10 [00:07<00:01,  1.17it/s]

Fetch: 38 :  True
Step_Train: 38 :  True
Fetch: 39 :  True
Step_Train: 39 :  True
Fetch: 40 :  True
Step_Train: 40 :  True
Fetch: 1 :  True
Step_Val: 1 :  True
Fetch: 2 :  True
Step_Val: 2 :  True
Fetch: 3 :  True
Step_Val: 3 :  True
Fetch: 4 :  True
Step_Val: 4 :  True
Fetch: 5 :  True
Step_Val: 5 :  True
Fetch: 1 :  True
Step_Train: 1 :  True
Fetch: 2 :  True
Step_Train: 2 :  True
Fetch: 3 :  True
Step_Train: 3 :  True
Fetch: 4 :  True
Step_Train: 4 :  True
Fetch: 5 :  True
Step_Train: 5 :  True
Fetch: 6 :  True
Step_Train: 6 :  True
Fetch: 7 :  True
Step_Train: 7 :  True
Fetch: 8 :  True
Step_Train: 8 :  True
Fetch: 9 :  True
Step_Train: 9 :  True
Fetch: 10 :  True
Step_Train: 10 :  True
Fetch: 11 :  True
Step_Train: 11 :  True
Fetch: 12 :  True
Step_Train: 12 :  True
Fetch: 13 :  True
Step_Train: 13 :  True
Fetch: 14 :  True
Step_Train: 14 :  True
Fetch: 15 :  True
Step_Train: 15 :  True
Fetch: 16 :  True
Step_Train: 16 :  True
Fetch: 17 :  True
Step_Train: 17 :  True
Fetch: 18 :  

 90%|█████████ | 9/10 [00:07<00:00,  1.15it/s]

Fetch: 36 :  True
Step_Train: 36 :  True
Fetch: 37 :  True
Step_Train: 37 :  True
Fetch: 38 :  True
Step_Train: 38 :  True
Fetch: 39 :  True
Step_Train: 39 :  True
Fetch: 40 :  True
Step_Train: 40 :  True
Fetch: 1 :  True
Step_Val: 1 :  True
Fetch: 2 :  True
Step_Val: 2 :  True
Fetch: 3 :  True
Step_Val: 3 :  True
Fetch: 4 :  True
Step_Val: 4 :  True
Fetch: 5 :  True
Step_Val: 5 :  True
Fetch: 1 :  True
Step_Train: 1 :  True
Fetch: 2 :  True
Step_Train: 2 :  True
Fetch: 3 :  True
Step_Train: 3 :  True
Fetch: 4 :  True
Step_Train: 4 :  True
Fetch: 5 :  True
Step_Train: 5 :  True
Fetch: 6 :  True
Step_Train: 6 :  True
Fetch: 7 :  True
Step_Train: 7 :  True
Fetch: 8 :  True
Step_Train: 8 :  True
Fetch: 9 :  True
Step_Train: 9 :  True
Fetch: 10 :  True
Step_Train: 10 :  True
Fetch: 11 :  True
Step_Train: 11 :  True
Fetch: 12 :  True
Step_Train: 12 :  True
Fetch: 13 :  True
Step_Train: 13 :  True
Fetch: 14 :  True
Step_Train: 14 :  True
Fetch: 15 :  True
Step_Train: 15 :  True
Fetch: 16 :  

100%|██████████| 10/10 [00:08<00:00,  1.12it/s]

Fetch: 40 :  True
Step_Train: 40 :  True
Fetch: 1 :  True
Step_Val: 1 :  True
Fetch: 2 :  True
Step_Val: 2 :  True
Fetch: 3 :  True
Step_Val: 3 :  True
Fetch: 4 :  True
Step_Val: 4 :  True
Fetch: 5 :  True
Step_Val: 5 :  True





In [12]:
# Store the CNN train and validation losses for the analysis notebook.
# Forward slashes work on Windows and POSIX alike; '\\' only works on Windows.
torch.save(train_losses, 'tensor_collection/cnn_train_losses.pt')
torch.save(val_losses, 'tensor_collection/cnn_val_losses.pt')




### Model 3 - Training Multi-Modal Network ###

#### Model 3 - Hyperparameters

In [13]:
# -1 because we don't count the filename column
num_numerical_features = nn_x_train.shape[1] - 1

# Seed torch before building the network so the weight initialisation is
# reproducible across Restart & Run All (assumes MultiModalNetwork draws its
# initial weights from torch's default RNG -- confirm in networks.py).
torch.manual_seed(42)
model = MultiModalNetwork(num_numerical_features).to(device)  # sized by the number of numerical features
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 10  # after 10 epochs the model train and validation loss began to spike
loss_fn = nn.MSELoss()  # same alias as the CNN section
batch_sz = 200


#### Instantiating Train and Val Dataloaders

In [14]:
# Rebuild the batch iterators with the multi-modal batch size.
# randomize=True for training; validation keeps randomize=False.
train_dataloader = CustomDataloader(
    dataframe=nn_x_train,
    age=nn_y_train_tensor,
    batch_size=batch_sz,
    randomize=True,
)
val_dataloader = CustomDataloader(
    dataframe=nn_x_val,
    age=nn_y_val_tensor,
    batch_size=batch_sz,
    randomize=False,
)


#### Model 3 - Training Loop

In [15]:
# Multi-modal training loop. Each epoch calls mmn_train_step on the training loader,
# mmn_val_step on the validation loader, then writes one checkpoint file.
# save_checkpoint targets model_checkpoints/, so make sure that directory exists
# (e.g. on a fresh checkout) before the first epoch finishes.
os.makedirs('model_checkpoints', exist_ok=True)

train_losses = []
val_losses = []
for epoch in tqdm.tqdm(range(epochs), desc='MMN epochs'):
    # Training step
    train_loss = mmn_train_step(model, train_dataloader, loss_fn, optimizer, device)
    train_losses.append(train_loss)

    # Validation step
    val_loss = mmn_val_step(model, val_dataloader, loss_fn, device)
    val_losses.append(val_loss)

    # One checkpoint per epoch, so any epoch can be restored later.
    checkpoint_filename = f'model_checkpoints/mmn_checkpoint_epoch_{epoch}.pth'
    save_checkpoint(model, optimizer, epoch, checkpoint_filename)

  0%|          | 0/10 [00:00<?, ?it/s]

Fetch: 1 :  True
Fetch: 2 :  True
Fetch: 3 :  True
Fetch: 4 :  True


 10%|█         | 1/10 [00:00<00:07,  1.16it/s]

Fetch: 1 :  True
Fetch: 1 :  True
Fetch: 2 :  True
Fetch: 3 :  True


 20%|██        | 2/10 [00:01<00:06,  1.19it/s]

Fetch: 4 :  True
Fetch: 1 :  True
Fetch: 1 :  True
Fetch: 2 :  True
Fetch: 3 :  True
Fetch: 4 :  True


 30%|███       | 3/10 [00:02<00:05,  1.18it/s]

Fetch: 1 :  True
Fetch: 1 :  True
Fetch: 2 :  True
Fetch: 3 :  True


 40%|████      | 4/10 [00:03<00:05,  1.17it/s]

Fetch: 4 :  True
Fetch: 1 :  True
Fetch: 1 :  True
Fetch: 2 :  True
Fetch: 3 :  True
Fetch: 4 :  True


 50%|█████     | 5/10 [00:04<00:04,  1.17it/s]

Fetch: 1 :  True
Fetch: 1 :  True
Fetch: 2 :  True
Fetch: 3 :  True
Fetch: 4 :  True


 60%|██████    | 6/10 [00:05<00:03,  1.15it/s]

Fetch: 1 :  True
Fetch: 1 :  True
Fetch: 2 :  True
Fetch: 3 :  True
Fetch: 4 :  True


 70%|███████   | 7/10 [00:06<00:02,  1.14it/s]

Fetch: 1 :  True
Fetch: 1 :  True
Fetch: 2 :  True
Fetch: 3 :  True


 80%|████████  | 8/10 [00:06<00:01,  1.15it/s]

Fetch: 4 :  True
Fetch: 1 :  True
Fetch: 1 :  True
Fetch: 2 :  True
Fetch: 3 :  True


 90%|█████████ | 9/10 [00:07<00:00,  1.15it/s]

Fetch: 4 :  True
Fetch: 1 :  True
Fetch: 1 :  True
Fetch: 2 :  True
Fetch: 3 :  True
Fetch: 4 :  True


100%|██████████| 10/10 [00:08<00:00,  1.16it/s]

Fetch: 1 :  True





In [16]:
# Store the multi-modal train and validation losses for the analysis notebook.
# Forward slashes work on Windows and POSIX alike; '\\' only works on Windows.
torch.save(train_losses, 'tensor_collection/mmn_train_losses.pt')
torch.save(val_losses, 'tensor_collection/mmn_val_losses.pt')
# Store the epoch count (a plain int) for plotting
torch.save(epochs, 'tensor_collection/mmn_epochs.pt')

