In [1]:
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
from PIL import ImageOps
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
from torch import nn, optim
from torchvision import transforms
from torch import autograd
import torch.nn.functional as F
import os, os.path
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
from sklearn.preprocessing import OneHotEncoder
import re
import itertools
import functools
import operator

In [17]:
## Select type of data to train on
dataset = 1 # Aggregate = 1, Planar = 2, Graupel = 3, Columnar = 4

## Select correlation to use: w/ or w/out fractal dimension
correlation = 2 # Holzer and Sommerfeld = 1, Mola = 2
# Holzer and Sommerfeld fit to solid particles: use for planar, graupel, and columnar
# Mola fit to aggregates: use for aggregates
#
# Mola may perform better sometimes on the other shapes, but results for fractal dimension are
# non-physical on them: shape parameter outputs are less physically reliable

## Physics-guided loss on or off
custom_loss = 1 # 1 = True, 2 = False


In [18]:
## Set directory for images and data
dir_agg = '/Users/Crazz/Research Codes/CRREL files/Synthetic images/Aggregate_proj/aggregate_projections.csv'
dir_plan = '/Users/Crazz/Research Codes/CRREL files/Synthetic images/Planar_proj/planar_projections.csv'
dir_gr = '/Users/Crazz/Research Codes/CRREL files/Synthetic images/Graupel_proj/graupel_projections.csv'
dir_col = '/Users/Crazz/Research Codes/CRREL files/Synthetic images/Columnar_proj/columnar_projections.csv'
dir_imagg = '/Users/Crazz/Research Codes/CRREL files/Synthetic images/Aggregate_proj'
dir_implan = '/Users/Crazz/Research Codes/CRREL files/Synthetic images/Planar_proj'
dir_imgr = '/Users/Crazz/Research Codes/CRREL files/Synthetic images/Graupel_proj'
dir_imcol = '/Users/Crazz/Research Codes/CRREL files/Synthetic images/Columnar_proj'
## Set directory and name for saving model. Will overwrite a file of the same name
dir_model = '/Users/Crazz/Research Codes/CRREL files/planhs5model.pth'

In [11]:
## 
def padding(img, expected_size):
    desired_size = expected_size
    delta_width = desired_size[0] - img.size[0]
    delta_height = desired_size[1] - img.size[1]
    pad_width = delta_width // 2
    pad_height = delta_height // 2
    padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
    return ImageOps.expand(img, padding)

# Load synthetic planar data & images
plan_geom = pd.read_csv(dir_plan)
plan_geom = plan_geom.rename(columns={'img_name': 'id'})
imgs = {}
path = dir_implan
valid_images = [".jpg",".gif",".png",".tga"]
for f in os.listdir(path):
    ext = os.path.splitext(f)[1]
    if ext.lower() not in valid_images:
        continue
    temp = Image.open(os.path.join(path,f))
    imgs[f] = temp.copy()
    temp.close()
plan_img = pd.DataFrame(imgs.items(), columns=['id', 'images'])
plan_geom = plan_geom.merge(plan_img, how='inner', on='id')
plan_geom.loc[plan_geom.porosity==0,['porosity']] = 0.01

# Load synthetic graupel data & images
gr_geom = pd.read_csv(dir_gr)
gr_geom['img_name'] = gr_geom['img_name'] + '.png'
gr_geom = gr_geom.rename(columns={'img_name': 'id'})
imgs = {}
path = dir_imgr
valid_images = [".jpg",".gif",".png",".tga"]
for f in os.listdir(path):
    ext = os.path.splitext(f)[1]
    if ext.lower() not in valid_images:
        continue
    temp = Image.open(os.path.join(path,f))
    imgs[f] = temp.copy()
    temp.close()
gr_img = pd.DataFrame(imgs.items(), columns=['id', 'images'])
gr_geom = gr_geom.merge(gr_img, how='inner', on='id')

# Load synthetic aggregate data & images
agg_geom = pd.read_csv(dir_agg)
agg_geom = agg_geom.drop(['Unnamed: 0'], axis=1)
agg_geom = agg_geom.rename(columns={'image filename': 'id'})
agg_geom = agg_geom.drop(['convex hull minor axis length','convex hull perimeter','rotation around x-axis','rotation around y-axis','image resolution','area','perimeter','convex hull major axis length','rotation around z-axis','Feret diameter orthogonal to maximum Feret diameter','number of monomers','liquid water path'],axis=1)
imgs = {}
path = dir_imagg
valid_images = [".jpg",".gif",".png",".tga"]
for f in os.listdir(path):
    ext = os.path.splitext(f)[1]
    if ext.lower() not in valid_images:
        continue
    temp = Image.open(os.path.join(path,f))
    imgs[f] = temp.copy()
    temp.close()
img_df = pd.DataFrame(imgs.items(), columns=['id', 'images'])
agg_geom = agg_geom.merge(img_df, how='inner', on='id')

# Load synthetic columnar data and images
col_geom = pd.read_csv(dir_col)
col_geom = col_geom.rename(columns={'img_name': 'id'})
col_geom['id'] = col_geom['id'] + '.png'
imgs = {}
path = dir_imcol
valid_images = [".jpg",".gif",".png",".tga"]
for f in os.listdir(path):
    ext = os.path.splitext(f)[1]
    if ext.lower() not in valid_images:
        continue
    temp = Image.open(os.path.join(path,f))
    imgs[f] = temp.copy()
    temp.close()
col_img = pd.DataFrame(imgs.items(), columns=['id', 'images'])
col_geom = col_geom.merge(col_img, how='inner', on='id')
col_geom.loc[col_geom.porosity==0,['porosity']] = 0.01

# create index to be fed to the NN
plan_geom['tmp_idx'] = np.arange(0,len(plan_geom),1)
gr_geom['tmp_idx'] = np.arange(0,len(gr_geom),1)
agg_geom['tmp_idx'] = np.arange(0,len(agg_geom),1)
col_geom['tmp_idx'] = np.arange(0,len(col_geom),1)
#normalize shape param data- here only a linear normalization from 0-1: divide by the max expected value
plan_geom['por_n'] = (plan_geom['porosity'] - 0)/(1-0)
plan_geom['sar_n'] = (plan_geom['sar'] - 0)/(0.5-0)
plan_geom['sph_n'] = (plan_geom['sph'] - 0)/(1-0)
plan_geom['csph_n'] = (plan_geom['cross_sph'] - 0)/(3.5-0)
plan_geom['lsph_n'] = (plan_geom['length_sph'] - 0)/(2-0)
plan_geom['Acrat_n'] = (plan_geom['A_crat'] - 0)/(1.5-0)
plan_geom['Alrat_n'] = (plan_geom['A_lrat'] - 0)/(1.5-0)
plan_geom['Df_n'] = plan_geom['Df']/3
gr_geom['por_n'] = (gr_geom['porosity'] - 0)/(1-0)
gr_geom['sar_n'] = (gr_geom['sar'] - 0)/(0.5-0)
gr_geom['sph_n'] = (gr_geom['sph'] - 0)/(1-0)
gr_geom['csph_n'] = (gr_geom['cross_sph'] - 0)/(3.5-0)
gr_geom['lsph_n'] = (gr_geom['length_sph'] - 0)/(2-0)
gr_geom['Acrat_n'] = (gr_geom['A_crat'] - 0)/(1.5-0)
gr_geom['Alrat_n'] = (gr_geom['A_lrat'] - 0)/(1.5-0)
gr_geom['Df_n'] = gr_geom['Df']/3
agg_geom['por_n'] = (agg_geom['porosity'] - 0)/(1-0)
agg_geom['sar_n'] = (agg_geom['sar'] - 0)/(0.5-0)
agg_geom['sph_n'] = (agg_geom['sph'] - 0)/(1-0)
agg_geom['csph_n'] = (agg_geom['cross_sph'] - 0)/(3.5-0)
agg_geom['lsph_n'] = (agg_geom['length_sph'] - 0)/(2-0)
agg_geom['Acrat_n'] = (agg_geom['A_crat'] - 0)/(1.5-0)
agg_geom['Alrat_n'] = (agg_geom['A_lrat'] - 0)/(1.5-0)
agg_geom['Df_n'] = agg_geom['Df']/3
col_geom['por_n'] = (col_geom['porosity'] - 0)/(1-0)
col_geom['sar_n'] = (col_geom['sar'] - 0)/(0.5-0)
col_geom['sph_n'] = (col_geom['sph'] - 0)/(1-0)
col_geom['csph_n'] = (col_geom['cross_sph'] - 0)/(3.5-0)
col_geom['lsph_n'] = (col_geom['length_sph'] - 0)/(2-0)
col_geom['Acrat_n'] = (col_geom['A_crat'] - 0)/(1.5-0)
col_geom['Alrat_n'] = (col_geom['A_lrat'] - 0)/(1.5-0)
col_geom['Df_n'] = col_geom['Df']/3
# Manually split aggregate data so that the problem is not trivial for the NN
agg_geom['snowflake_class_id'] = 4
agg_geom_train = agg_geom[(agg_geom['ID']<160.0)].reset_index(drop=True) # roughly 77% - 23% train test split
agg_geom_test = agg_geom[(agg_geom['ID']>=160.0)].reset_index(drop=True)

In [19]:
import shutil
def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    f_path = checkpoint_dir / 'checkpoint.pt'
    torch.save(state, f_path)
    if is_best:
        best_fpath = best_model_dir / 'best_model.pt'
        shutil.copyfile(f_path, best_fpath)
def load_ckp(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']

class Synthetic_geom1(torch.utils.data.Dataset):
    def __init__(self, X):
        self.X = X['images']
        self.label = X['tmp_idx']
        self.sph = X['sph_n']
        self.porosity = X['por_n']
        self.sar = X['sar_n']
        self.length_sph = X['lsph_n']
        self.cross_sph = X['csph_n']
        self.class_id = X['snowflake_class_id']
        self.A_crat = X['Acrat_n']
        self.A_lrat = X['Alrat_n']

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        X = (self.X[idx])
        label = torch.tensor(self.label[idx])
        sph = torch.tensor(self.sph[idx], dtype=torch.float32)
        porosity = torch.tensor(self.porosity[idx], dtype=torch.float32)
        sar = torch.tensor(self.sar[idx], dtype=torch.float32)
        length_sph = torch.tensor(self.length_sph[idx], dtype=torch.float32)
        cross_sph = torch.tensor(self.cross_sph[idx], dtype=torch.float32)
        class_id = torch.tensor(self.class_id[idx], dtype=torch.float32)
        A_crat = torch.tensor(self.A_crat[idx], dtype=torch.float32)
        A_lrat = torch.tensor(self.A_lrat[idx], dtype=torch.float32)
        
        transform = transforms.Compose([
            transforms.RandomInvert(0.5),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))]) # might need to change
        # seperate ys
        return label, transform(X), porosity,sar,sph,cross_sph,length_sph,class_id,A_crat,A_lrat

class Synthetic_geom2(torch.utils.data.Dataset):
    def __init__(self, X):
        self.X = X['images']
        self.label = X['tmp_idx']
        self.sph = X['sph_n']
        self.porosity = X['por_n']
        self.sar = X['sar_n']
        self.length_sph = X['lsph_n']
        self.cross_sph = X['csph_n']
        self.class_id = X['snowflake_class_id']
        self.A_crat = X['Acrat_n']
        self.A_lrat = X['Alrat_n']
        self.Df = X['Df_n']

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        X = (self.X[idx])
        label = torch.tensor(self.label[idx])
        sph = torch.tensor(self.sph[idx], dtype=torch.float32)
        porosity = torch.tensor(self.porosity[idx], dtype=torch.float32)
        sar = torch.tensor(self.sar[idx], dtype=torch.float32)
        length_sph = torch.tensor(self.length_sph[idx], dtype=torch.float32)
        cross_sph = torch.tensor(self.cross_sph[idx], dtype=torch.float32)
        class_id = torch.tensor(self.class_id[idx], dtype=torch.float32)
        A_crat = torch.tensor(self.A_crat[idx], dtype=torch.float32)
        A_lrat = torch.tensor(self.A_lrat[idx], dtype=torch.float32)
        Df = torch.tensor(self.Df[idx], dtype=torch.float32)
        
        transform = transforms.Compose([
            transforms.RandomInvert(0.5),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))]) # might need to change
        # seperate ys
        return label, transform(X), porosity,sar,sph,cross_sph,length_sph,class_id,A_crat,A_lrat,Df

# Train settings
if dataset == 1: # agg
    batch_size = 32
    lr = 5e-5
    num_epochs = 2000
    if correlation == 1:
        agg_train = Synthetic_geom1(agg_geom_train)
        agg_test = Synthetic_geom1(agg_geom_test)
    if correlation == 2:
        agg_train = Synthetic_geom2(agg_geom_train)
        agg_test = Synthetic_geom2(agg_geom_test)
    if correlation == 3:
        agg_train = Synthetic_geom1(agg_geom_train)
        agg_test = Synthetic_geom1(agg_geom_test)
    train_size = len(agg_train)
    test_size = len(agg_test)
    train_loader = DataLoader(agg_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(agg_test)
elif dataset == 2: #plan
    batch_size = 10
    lr = 1e-6
    num_epochs = 600
    if correlation == 1:
        full_data = Synthetic_geom1(plan_geom)
    if correlation == 2:
        full_data = Synthetic_geom2(plan_geom)
    if correlation == 3:
        full_data = Synthetic_geom1(plan_geom)
    train_size = int(0.8 * len(full_data))
    test_size = len(full_data) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_data, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset)
elif dataset == 3: #graup
    batch_size = 3
    lr = 1e-6
    num_epochs = 500
    if correlation == 1:
        full_data = Synthetic_geom1(gr_geom)
    if correlation == 2:
        full_data = Synthetic_geom2(gr_geom)
    if correlation == 3:
        full_data = Synthetic_geom1(gr_geom)
    train_size = int(0.8 * len(full_data))
    test_size = len(full_data) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_data, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset)
elif dataset == 4: #col
    batch_size = 3
    lr = 1e-6
    num_epochs = 500
    if correlation == 1:
        full_data = Synthetic_geom1(col_geom)
    if correlation == 2:
        full_data = Synthetic_geom2(col_geom)
    if correlation == 3:
        full_data = Synthetic_geom1(col_geom)
    train_size = int(0.8 * len(full_data))
    test_size = len(full_data) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_data, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset)
print('Train set size:', train_size)
print('Test set size:', test_size)

class KGCNN1(nn.Module):
    def __init__(self):
        super(KGCNN1, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5),#, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 48, 3),# stride=1, padding=0),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(48, 64, 3),# stride=1, padding=0),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 80, 3),# stride=1, padding=0),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
        )
        self.geom = nn.Sequential(
            nn.Linear(80 * 14 * 14, 50), # 48*12*12 if size is 256, *4*4 128
            nn.ReLU(True),
            nn.Dropout(0.15),
            nn.Linear(50, 7), # 7 geoms return
            nn.Softplus(),
        )
    def forward(self, image):
        top_out = self.conv(image)
        middle = top_out.view(top_out.size(0), -1)
        geoms_out = self.geom(middle)
        #print(image.size())
        #print(middle1.size())
        #print(data.size())
        return geoms_out
    
class KGCNN2(nn.Module):
    def __init__(self):
        super(KGCNN2, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5),#, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 48, 3),# stride=1, padding=0),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(48, 64, 3),# stride=1, padding=0),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 80, 3),# stride=1, padding=0),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
        )
        self.geom = nn.Sequential(
            nn.Linear(80 * 14 * 14, 50), # 48*12*12 if size is 256, *4*4 128
            nn.ReLU(True),
            nn.Dropout(0.15),
            nn.Linear(50, 8), # 8 geoms return
            nn.Softplus(),
        )
    def forward(self, image):
        top_out = self.conv(image)
        middle = top_out.view(top_out.size(0), -1)
        geoms_out = self.geom(middle)
        #print(image.size())
        #print(middle1.size())
        #print(data.size())
        return geoms_out

# Custom physics loss
def physics_loss(sar, sph, length_sph, cross_sph, A_crat, A_lrat):
    
    tiny = 0.000000001
    
    loss_c = (sph/(4*sar) - A_crat*cross_sph)**2
    loss_l = (sph/(4*sar) - length_sph*(1/(2*sar) - A_lrat))**2
    loss_c = torch.mean(loss_c)
    loss_l = torch.mean(loss_l)
    
    # penalize geom outputs if outside range
    if torch.max(porosity)>=1:
        loss_por = 100*(porosity-0.9)**2
    else:
        loss_por = torch.tensor(0.000000001, requires_grad=True)
        
    if torch.max(sar) >= 0.5:
        loss_sar = 100*(sar-0.4)**2
    else:
        loss_sar = torch.tensor(0.000000001, requires_grad=True)
        
    if torch.max(sph) >= 1:
        loss_sph = 100*(sph-0.9)**2
    else:
        loss_sph = torch.tensor(0.000000001, requires_grad=True)

    loss_pg = torch.mean(loss_c + loss_l + loss_por + loss_sar + loss_sph)/5
    
    return loss_c, loss_l, loss_pg

# use GPU for computations
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(torch.cuda.is_available())
if correlation == 1:
    model = KGCNN1().to(device)  # Create an instance of the model
elif correlation == 2:
    model = KGCNN2().to(device)
#device = torch.device('cpu')

if correlation == 1:
    model = KGCNN1().to(device)
    # freeze Cd layer when training
    #    for name, param in model.named_parameters():
    #    if param.requires_grad and 'Cd' in name:
    #        param.requires_grad = False
    non_frozen_parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(non_frozen_parameters, lr=lr)  # Choose an optimizer

    criterion = nn.MSELoss() #MSELoss()  # Mean squared error loss #L1Loss()

    train_llist_Geom = []
    train_llist_phys = []
    train_llist_por = []
    train_llist_sar = []
    train_llist_sph = []
    train_llist_c_sph = []
    train_llist_l_sph = []
    train_llist_Acrat = []
    train_llist_Alrat = []
    train_llist_pg = []

    test_llist_Geom = []
    test_llist_phys = []
    test_llist_por = []
    test_llist_sar = []
    test_llist_sph = []
    test_llist_c_sph = []
    test_llist_l_sph = []
    test_llist_Acrat = []
    test_llist_Alrat = []
    test_llist_pg = []

    best_test_acc = 100

    # Training loop
    train_loss = []

    for epoch in range(num_epochs):
        model.train()  # Set the model in training mode

        train_Geom_loss = 0.0
        train_por_l = 0.0
        train_sar_l = 0.0
        train_sph_l = 0.0
        train_c_sph_l = 0.0
        train_l_sph_l = 0.0
        train_Acrat_l = 0.0
        train_Alrat_l = 0.0
        train_pg = 0.0

        for id,images,por_dat,sar_dat,sph_dat,c_sph_dat,l_sph_dat,class_id,A_crat_dat,A_lrat_dat in train_loader:

            id = id.to(device)
            images = Variable(images, requires_grad=True)

            por_dat = por_dat.view(por_dat.size(0),1)
            sar_dat = sar_dat.view(por_dat.size(0),1)
            sph_dat = sph_dat.view(por_dat.size(0),1)
            c_sph_dat = c_sph_dat.view(c_sph_dat.size(0),1)
            l_sph_dat = l_sph_dat.view(l_sph_dat.size(0),1)
            A_crat_dat = A_crat_dat.view(A_crat_dat.size(0),1)
            A_lrat_dat = A_lrat_dat.view(A_lrat_dat.size(0),1)
            class_id = class_id.view(class_id.size(0),1)

            images = images.to(device)
            por_dat = por_dat.to(device)
            sar_dat = sar_dat.to(device)
            sph_dat = sph_dat.to(device)
            c_sph_dat = c_sph_dat.to(device)
            l_sph_dat = l_sph_dat.to(device)
            A_crat_dat = A_crat_dat.to(device)
            A_lrat_dat = A_lrat_dat.to(device)
            class_id = class_id.to(device)

            #with autograd.detect_anomaly():
            # Forward pass
            geoms_out = model(images)
            sph = geoms_out[:,0]
            cross_sph = geoms_out[:,1]
            length_sph = geoms_out[:,2]
            porosity = geoms_out[:,3]
            sar = geoms_out[:,4]
            A_crat = geoms_out[:,5]
            A_lrat = geoms_out[:,6]

            porosity = porosity.view(porosity.size(0),1)
            sar = sar.view(sar.size(0),1)
            sph = sph.view(sph.size(0),1)
            cross_sph = cross_sph.view(cross_sph.size(0),1)
            length_sph = length_sph.view(length_sph.size(0),1)
            A_crat = A_crat.view(A_crat.size(0),1)
            A_lrat = A_lrat.view(A_lrat.size(0),1)

            #### LOSS
            #PHYSICS GUIDED LOSS
            loss_c,loss_l,loss_pg = physics_loss(sar*0.5,sph,length_sph*2,cross_sph*3.5,A_crat*1.5,A_lrat*1.5)

            # regular geom loss, porosity and c_sph known in mascdb
            loss_por = criterion(porosity,por_dat)
            loss_sar = criterion(sar,sar_dat)
            loss_sph = criterion(sph,sph_dat)
            loss_c_sph = criterion(cross_sph,c_sph_dat)
            loss_l_sph = criterion(length_sph,l_sph_dat)
            loss_Acrat = criterion(A_crat,A_crat_dat)
            loss_Alrat = criterion(A_lrat,A_lrat_dat)
            loss_Geom = loss_por + loss_sar + loss_sph + loss_c_sph + loss_l_sph + loss_Acrat + loss_Alrat

            # TOTAL GEOM LOSS
            if custom_loss == 1:
                loss_Geom_tot = loss_Geom + loss_pg
            if custom_loss == 2:
                loss_Geom_tot = loss_Geom

            # Backward pass and optimization
            loss_Geom_tot.backward()
            #nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            #####
            train_Geom_loss += loss_Geom * images.size(0)
            train_pg += loss_pg * images.size(0)

            train_por_l += loss_por * images.size(0)
            train_sar_l += loss_sar * images.size(0)
            train_sph_l += loss_sph * images.size(0)
            train_c_sph_l += loss_c_sph * images.size(0)
            train_l_sph_l += loss_l_sph * images.size(0)
            train_Acrat_l += loss_Acrat * images.size(0)
            train_Alrat_l += loss_Alrat * images.size(0)

        train_Geom_loss /= len(train_loader.dataset)
        train_por_l /= len(train_loader.dataset)
        train_sar_l /= len(train_loader.dataset)
        train_sph_l /= len(train_loader.dataset)
        train_c_sph_l /= len(train_loader.dataset)
        train_l_sph_l /= len(train_loader.dataset)
        train_Acrat_l /= len(train_loader.dataset)
        train_Alrat_l /= len(train_loader.dataset)
        train_pg /= len(train_loader.dataset)

        train_llist_Geom.append(train_Geom_loss.item())
        train_llist_por.append(train_por_l.item())
        train_llist_sar.append(train_sar_l.item())
        train_llist_sph.append(train_sph_l.item())
        train_llist_c_sph.append(train_c_sph_l.item())
        train_llist_l_sph.append(train_l_sph_l.item())
        train_llist_Acrat.append(train_Acrat_l.item())
        train_llist_Alrat.append(train_Alrat_l.item())
        train_llist_pg.append(train_pg.item())

        # Evaluation
        model.eval()  # Set the model in evaluation mode

        test_Geom_loss = 0.0
        test_por_l = 0.0
        test_sar_l = 0.0
        test_sph_l = 0.0
        test_c_sph_l = 0.0
        test_l_sph_l = 0.0
        test_Acrat_l = 0.0
        test_Alrat_l = 0.0
        test_pg = 0.0

        with torch.no_grad():
            for id,images,por_dat,sar_dat,sph_dat,c_sph_dat,l_sph_dat,class_id,A_crat_dat,A_lrat_dat in test_loader:

                id = id.to(device)
                images = Variable(images, requires_grad=True)

                por_dat = por_dat.view(por_dat.size(0),1)
                sar_dat = sar_dat.view(por_dat.size(0),1)
                sph_dat = sph_dat.view(por_dat.size(0),1)
                c_sph_dat = c_sph_dat.view(c_sph_dat.size(0),1)
                l_sph_dat = l_sph_dat.view(l_sph_dat.size(0),1)
                A_crat_dat = A_crat_dat.view(A_crat_dat.size(0),1)
                A_lrat_dat = A_lrat_dat.view(A_lrat_dat.size(0),1)
                class_id = class_id.view(class_id.size(0),1)

                images = images.to(device)
                por_dat = por_dat.to(device)
                sar_dat = sar_dat.to(device)
                sph_dat = sph_dat.to(device)
                c_sph_dat = c_sph_dat.to(device)
                l_sph_dat = l_sph_dat.to(device)
                A_crat_dat = A_crat_dat.to(device)
                A_lrat_dat = A_lrat_dat.to(device)
                class_id = class_id.to(device)

                #with autograd.detect_anomaly():
                # Forward pass
                geoms_out = model(images)
                sph = geoms_out[:,0]
                cross_sph = geoms_out[:,1]
                length_sph = geoms_out[:,2]
                porosity = geoms_out[:,3]
                sar = geoms_out[:,4]
                A_crat = geoms_out[:,5]
                A_lrat = geoms_out[:,6]

                porosity = porosity.view(porosity.size(0),1)
                sar = sar.view(sar.size(0),1)
                sph = sph.view(sph.size(0),1)
                cross_sph = cross_sph.view(cross_sph.size(0),1)
                length_sph = length_sph.view(length_sph.size(0),1)
                A_crat = A_crat.view(A_crat.size(0),1)
                A_lrat = A_lrat.view(A_lrat.size(0),1)

                #### LOSS
                #PHYSICS GUIDED LOSS
                loss_c,loss_l,loss_pg = physics_loss(sar*0.5,sph,length_sph*2,cross_sph*3.5,A_crat*1.5,A_lrat*1.5)

                # regular geom loss, porosity and c_sph known in mascdb
                loss_por = criterion(porosity,por_dat)
                loss_sar = criterion(sar,sar_dat)
                loss_sph = criterion(sph,sph_dat)
                loss_c_sph = criterion(cross_sph,c_sph_dat)
                loss_l_sph = criterion(length_sph,l_sph_dat)
                loss_Acrat = criterion(A_crat,A_crat_dat)
                loss_Alrat = criterion(A_lrat,A_lrat_dat)
                loss_Geom = loss_por + loss_sar + loss_sph + loss_c_sph + loss_l_sph + loss_Acrat + loss_Alrat

                # TOTAL GEOM LOSS
                if custom_loss == 1:
                    loss_Geom_tot = loss_Geom + loss_pg
                if custom_loss == 2:
                    loss_Geom_tot = loss_Geom

                # TOTAL loss w/o phys Cd
                # sum physics guided and regular losses from Cd and geom parameters

                test_Geom_loss += loss_Geom * images.size(0)
                test_pg += loss_pg * images.size(0)

                test_por_l += loss_por * images.size(0)
                test_sar_l += loss_sar * images.size(0)
                test_sph_l += loss_sph * images.size(0)
                test_c_sph_l += loss_c_sph * images.size(0)
                test_l_sph_l += loss_l_sph * images.size(0)
                test_Acrat_l += loss_Acrat * images.size(0)
                test_Alrat_l += loss_Alrat * images.size(0)

        test_Geom_loss /= len(test_loader.dataset)
        test_por_l /= len(test_loader.dataset)
        test_sar_l /= len(test_loader.dataset)
        test_sph_l /= len(test_loader.dataset)
        test_c_sph_l /= len(test_loader.dataset)
        test_l_sph_l /= len(test_loader.dataset)
        test_Acrat_l /= len(test_loader.dataset)
        test_Alrat_l /= len(test_loader.dataset)
        test_pg /= len(test_loader.dataset)

        test_llist_Geom.append(test_Geom_loss.item())
        test_llist_por.append(test_por_l.item())
        test_llist_sar.append(test_sar_l.item())
        test_llist_sph.append(test_sph_l.item())
        test_llist_c_sph.append(test_c_sph_l.item())
        test_llist_l_sph.append(test_l_sph_l.item())
        test_llist_Acrat.append(test_Acrat_l.item())
        test_llist_Alrat.append(test_Alrat_l.item())
        test_llist_pg.append(test_pg.item())

        if test_Geom_loss < best_test_acc:
            checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
            }
            torch.save(checkpoint, dir_model)
            goodplanmodel_epoch = epoch
            best_test_acc = test_Geom_loss

        print(f"Epoch {epoch+1}/{num_epochs} - Train geom Loss: {train_Geom_loss:.6f} - Test Loss: {test_Geom_loss:.6f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train phys geom Loss: {train_pg:.6f} - Test Loss: {test_pg:.6f}")
        #print(f"Epoch {epoch+1}/{num_epochs} - Train phys Cd Loss: {train_p_Cd_loss:.4f} - Test Loss: {test_p_Cd_loss:.4f}")
        #print(f"Epoch {epoch+1}/{num_epochs} - Train TOTAL Loss: {train_loss_TOTAL:.4f} - Test Loss: {test_loss_TOTAL:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train por loss: {train_por_l:.4f} - Test loss: {test_por_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train sar loss: {train_sar_l:.4f} - Test loss: {test_sar_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train sph loss: {train_sph_l:.4f} - Test loss: {test_sph_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train c_sph loss: {train_c_sph_l:.4f} - Test loss: {test_c_sph_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train l_sph loss: {train_l_sph_l:.4f} - Test loss: {test_l_sph_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train Acrat loss: {train_Acrat_l:.4f} - Test loss: {test_Acrat_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train Alrat loss: {train_Alrat_l:.4f} - Test loss: {test_Alrat_l:.4f}")
        print('\n')

if correlation == 2:
    model = KGCNN2().to(device)
    # freeze Cd layer when training purely for geoms
    #for name, param in model.named_parameters():
    #    if param.requires_grad and 'Cd' in name:
    #        param.requires_grad = False
    non_frozen_parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(non_frozen_parameters, lr=lr)  # Choose an optimizer

    criterion = nn.MSELoss() #MSELoss()  # Mean squared error loss #L1Loss()

    train_llist_Geom = []
    train_llist_phys = []
    train_llist_por = []
    train_llist_sar = []
    train_llist_sph = []
    train_llist_c_sph = []
    train_llist_l_sph = []
    train_llist_Acrat = []
    train_llist_Alrat = []
    train_llist_Df = []
    train_llist_pg = []

    test_llist_Geom = []
    test_llist_phys = []
    test_llist_por = []
    test_llist_sar = []
    test_llist_sph = []
    test_llist_c_sph = []
    test_llist_l_sph = []
    test_llist_Acrat = []
    test_llist_Alrat = []
    test_llist_Df = []
    test_llist_pg = []

    best_test_acc = 100

    # Training loop
    train_loss = []

    for epoch in range(num_epochs):
        model.train()  # Set the model in training mode

        train_Geom_loss = 0.0
        train_por_l = 0.0
        train_sar_l = 0.0
        train_sph_l = 0.0
        train_c_sph_l = 0.0
        train_l_sph_l = 0.0
        train_Acrat_l = 0.0
        train_Alrat_l = 0.0
        train_Df_l = 0.0
        train_pg = 0.0

        for id,images,por_dat,sar_dat,sph_dat,c_sph_dat,l_sph_dat,class_id,A_crat_dat,A_lrat_dat,Df_dat in train_loader:

            id = id.to(device)
            images = Variable(images, requires_grad=True)

            por_dat = por_dat.view(por_dat.size(0),1)
            sar_dat = sar_dat.view(por_dat.size(0),1)
            sph_dat = sph_dat.view(por_dat.size(0),1)
            c_sph_dat = c_sph_dat.view(c_sph_dat.size(0),1)
            l_sph_dat = l_sph_dat.view(l_sph_dat.size(0),1)
            A_crat_dat = A_crat_dat.view(A_crat_dat.size(0),1)
            A_lrat_dat = A_lrat_dat.view(A_lrat_dat.size(0),1)
            class_id = class_id.view(class_id.size(0),1)
            Df_dat = Df_dat.view(Df_dat.size(0),1)

            images = images.to(device)
            por_dat = por_dat.to(device)
            sar_dat = sar_dat.to(device)
            sph_dat = sph_dat.to(device)
            c_sph_dat = c_sph_dat.to(device)
            l_sph_dat = l_sph_dat.to(device)
            A_crat_dat = A_crat_dat.to(device)
            A_lrat_dat = A_lrat_dat.to(device)
            class_id = class_id.to(device)
            Df_dat = Df_dat.to(device)

            #with autograd.detect_anomaly():
            # Forward pass
            geoms_out = model(images)
            sph = geoms_out[:,0]
            cross_sph = geoms_out[:,1]
            length_sph = geoms_out[:,2]
            porosity = geoms_out[:,3]
            sar = geoms_out[:,4]
            A_crat = geoms_out[:,5]
            A_lrat = geoms_out[:,6]
            Df = geoms_out[:,7]

            porosity = porosity.view(porosity.size(0),1)
            sar = sar.view(sar.size(0),1)
            sph = sph.view(sph.size(0),1)
            cross_sph = cross_sph.view(cross_sph.size(0),1)
            length_sph = length_sph.view(length_sph.size(0),1)
            A_crat = A_crat.view(A_crat.size(0),1)
            A_lrat = A_lrat.view(A_lrat.size(0),1)
            Df = Df.view(Df.size(0),1)

            #### LOSS
            #PHYSICS GUIDED LOSS
            loss_c,loss_l,loss_pg = physics_loss(sar*0.5,sph,length_sph*2,cross_sph*3.5,A_crat*1.5,A_lrat*1.5)

            # regular geom loss, porosity and c_sph known in mascdb
            loss_por = criterion(porosity,por_dat)
            loss_sar = criterion(sar,sar_dat)
            loss_sph = criterion(sph,sph_dat)
            loss_c_sph = criterion(cross_sph,c_sph_dat)
            loss_l_sph = criterion(length_sph,l_sph_dat)
            loss_Acrat = criterion(A_crat,A_crat_dat)
            loss_Alrat = criterion(A_lrat,A_lrat_dat)
            loss_Df = criterion(Df,Df_dat)
            loss_Geom = loss_por + loss_sar + loss_sph + loss_c_sph + loss_l_sph + loss_Acrat + loss_Alrat + loss_Df

            # TOTAL GEOM LOSS
            if custom_loss == 1:
                loss_Geom_tot = loss_Geom + loss_pg
            if custom_loss == 2:
                loss_Geom_tot = loss_Geom

            # TOTAL loss w/o phys Cd
            # sum physics guided and regular losses from Cd and geom parameters
            #loss_TOTAL = loss_Geom_tot + loss_Cd + Cd_bias**2

            # Backward pass and optimization
            loss_Geom_tot.backward()
            #nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            #####
            train_Geom_loss += loss_Geom * images.size(0)
            train_pg += loss_pg * images.size(0)

            train_por_l += loss_por * images.size(0)
            train_sar_l += loss_sar * images.size(0)
            train_sph_l += loss_sph * images.size(0)
            train_c_sph_l += loss_c_sph * images.size(0)
            train_l_sph_l += loss_l_sph * images.size(0)
            train_Acrat_l += loss_Acrat * images.size(0)
            train_Alrat_l += loss_Alrat * images.size(0)
            train_Df_l += loss_Df * images.size(0)

        train_Geom_loss /= len(train_loader.dataset)
        train_por_l /= len(train_loader.dataset)
        train_sar_l /= len(train_loader.dataset)
        train_sph_l /= len(train_loader.dataset)
        train_c_sph_l /= len(train_loader.dataset)
        train_l_sph_l /= len(train_loader.dataset)
        train_Acrat_l /= len(train_loader.dataset)
        train_Alrat_l /= len(train_loader.dataset)
        train_Df_l /= len(train_loader.dataset)
        train_pg /= len(train_loader.dataset)

        train_llist_Geom.append(train_Geom_loss.item())
        train_llist_por.append(train_por_l.item())
        train_llist_sar.append(train_sar_l.item())
        train_llist_sph.append(train_sph_l.item())
        train_llist_c_sph.append(train_c_sph_l.item())
        train_llist_l_sph.append(train_l_sph_l.item())
        train_llist_Acrat.append(train_Acrat_l.item())
        train_llist_Alrat.append(train_Alrat_l.item())
        train_llist_Df.append(train_Df_l.item())
        train_llist_pg.append(train_pg.item())

        # Evaluation
        model.eval()  # Set the model in evaluation mode

        test_Geom_loss = 0.0
        test_por_l = 0.0
        test_sar_l = 0.0
        test_sph_l = 0.0
        test_c_sph_l = 0.0
        test_l_sph_l = 0.0
        test_Acrat_l = 0.0
        test_Alrat_l = 0.0
        test_Df_l = 0.0
        test_pg = 0.0

        with torch.no_grad():
            for id,images,por_dat,sar_dat,sph_dat,c_sph_dat,l_sph_dat,class_id,A_crat_dat,A_lrat_dat,Df_dat in test_loader:

                id = id.to(device)
                images = Variable(images, requires_grad=True)

                por_dat = por_dat.view(por_dat.size(0),1)
                sar_dat = sar_dat.view(por_dat.size(0),1)
                sph_dat = sph_dat.view(por_dat.size(0),1)
                c_sph_dat = c_sph_dat.view(c_sph_dat.size(0),1)
                l_sph_dat = l_sph_dat.view(l_sph_dat.size(0),1)
                A_crat_dat = A_crat_dat.view(A_crat_dat.size(0),1)
                A_lrat_dat = A_lrat_dat.view(A_lrat_dat.size(0),1)
                class_id = class_id.view(class_id.size(0),1)
                Df_dat = Df_dat.view(Df_dat.size(0),1)

                images = images.to(device)
                por_dat = por_dat.to(device)
                sar_dat = sar_dat.to(device)
                sph_dat = sph_dat.to(device)
                c_sph_dat = c_sph_dat.to(device)
                l_sph_dat = l_sph_dat.to(device)
                A_crat_dat = A_crat_dat.to(device)
                A_lrat_dat = A_lrat_dat.to(device)
                class_id = class_id.to(device)
                Df_dat = Df_dat.to(device)

                #with autograd.detect_anomaly():
                # Forward pass
                geoms_out = model(images)
                sph = geoms_out[:,0]
                cross_sph = geoms_out[:,1]
                length_sph = geoms_out[:,2]
                porosity = geoms_out[:,3]
                sar = geoms_out[:,4]
                A_crat = geoms_out[:,5]
                A_lrat = geoms_out[:,6]
                Df = geoms_out[:,7]

                porosity = porosity.view(porosity.size(0),1)
                sar = sar.view(sar.size(0),1)
                sph = sph.view(sph.size(0),1)
                cross_sph = cross_sph.view(cross_sph.size(0),1)
                length_sph = length_sph.view(length_sph.size(0),1)
                A_crat = A_crat.view(A_crat.size(0),1)
                A_lrat = A_lrat.view(A_lrat.size(0),1)
                Df = Df.view(Df.size(0),1)

                #### LOSS
                #PHYSICS GUIDED LOSS
                loss_c,loss_l,loss_pg = physics_loss(sar*0.5,sph,length_sph*2,cross_sph*3.5,A_crat*1.5,A_lrat*1.5)

                # regular geom loss, porosity and c_sph known in mascdb
                loss_por = criterion(porosity,por_dat)
                loss_sar = criterion(sar,sar_dat)
                loss_sph = criterion(sph,sph_dat)
                loss_c_sph = criterion(cross_sph,c_sph_dat)
                loss_l_sph = criterion(length_sph,l_sph_dat)
                loss_Acrat = criterion(A_crat,A_crat_dat)
                loss_Alrat = criterion(A_lrat,A_lrat_dat)
                loss_Df = criterion(Df,Df_dat)
                loss_Geom = loss_por + loss_sar + loss_sph + loss_c_sph + loss_l_sph + loss_Acrat + loss_Alrat + loss_Df

                # TOTAL GEOM LOSS
                if custom_loss == 1:
                    loss_Geom_tot = loss_Geom + loss_pg
                if custom_loss == 2:
                    loss_Geom_tot = loss_Geom

                # TOTAL loss w/o phys Cd
                # sum physics guided and regular losses from Cd and geom parameters

                test_Geom_loss += loss_Geom * images.size(0)
                test_pg += loss_pg * images.size(0)

                test_por_l += loss_por * images.size(0)
                test_sar_l += loss_sar * images.size(0)
                test_sph_l += loss_sph * images.size(0)
                test_c_sph_l += loss_c_sph * images.size(0)
                test_l_sph_l += loss_l_sph * images.size(0)
                test_Acrat_l += loss_Acrat * images.size(0)
                test_Alrat_l += loss_Alrat * images.size(0)

        test_Geom_loss /= len(test_loader.dataset)
        test_por_l /= len(test_loader.dataset)
        test_sar_l /= len(test_loader.dataset)
        test_sph_l /= len(test_loader.dataset)
        test_c_sph_l /= len(test_loader.dataset)
        test_l_sph_l /= len(test_loader.dataset)
        test_Acrat_l /= len(test_loader.dataset)
        test_Alrat_l /= len(test_loader.dataset)
        test_Df_l += loss_Df * images.size(0)
        test_pg /= len(test_loader.dataset)

        test_llist_Geom.append(test_Geom_loss.item())
        test_llist_por.append(test_por_l.item())
        test_llist_sar.append(test_sar_l.item())
        test_llist_sph.append(test_sph_l.item())
        test_llist_c_sph.append(test_c_sph_l.item())
        test_llist_l_sph.append(test_l_sph_l.item())
        test_llist_Acrat.append(test_Acrat_l.item())
        test_llist_Alrat.append(test_Alrat_l.item())
        test_llist_Df.append(test_Df_l.item())
        test_llist_pg.append(test_pg.item())

        if test_Geom_loss < best_test_acc:
            checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
            }
            torch.save(checkpoint, dir_model)
            goodplanmodel_epoch = epoch
            best_test_acc = test_Geom_loss

        print(f"Epoch {epoch+1}/{num_epochs} - Train geom Loss: {train_Geom_loss:.6f} - Test Loss: {test_Geom_loss:.6f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train phys geom Loss: {train_pg:.6f} - Test Loss: {test_pg:.6f}")
        #print(f"Epoch {epoch+1}/{num_epochs} - Train phys Cd Loss: {train_p_Cd_loss:.4f} - Test Loss: {test_p_Cd_loss:.4f}")
        #print(f"Epoch {epoch+1}/{num_epochs} - Train TOTAL Loss: {train_loss_TOTAL:.4f} - Test Loss: {test_loss_TOTAL:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train por loss: {train_por_l:.4f} - Test loss: {test_por_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train sar loss: {train_sar_l:.4f} - Test loss: {test_sar_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train sph loss: {train_sph_l:.4f} - Test loss: {test_sph_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train c_sph loss: {train_c_sph_l:.4f} - Test loss: {test_c_sph_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train l_sph loss: {train_l_sph_l:.4f} - Test loss: {test_l_sph_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train Acrat loss: {train_Acrat_l:.4f} - Test loss: {test_Acrat_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train Alrat loss: {train_Alrat_l:.4f} - Test loss: {test_Alrat_l:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} - Train Df loss: {train_Df_l:.4f} - Test loss: {test_Df_l:.4f}")
        print('\n')
        
print('Completed training. Best model saved at epoch ',goodplanmodel_epoch)
print('Now evaluating performance')
model.load_state_dict(torch.load(dir_model))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckp_path = dir_model
non_frozen_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(non_frozen_parameters, lr=lr)
model, optimizer, start_epoch = load_ckp(ckp_path, model, optimizer)

if correlation == 1:
    model.eval()  # Set the model in evaluation mode
    criterion = nn.MSELoss() #MSELoss()  # Mean squared error loss #L1Loss()
    outs = []
    por_out = []
    sar_out = []
    c_sph_out = []
    l_sph_out = []
    sph_out = []
    Acrat_out = []
    Alrat_out = []

    answers = []
    por_ans = []
    sar_ans = []
    c_sph_ans = []
    l_sph_ans = []
    sph_ans = []
    Acrat_ans = []
    Alrat_ans = []

    test_Geom_loss = 0.0
    test_p_Geom_loss = 0.0

    test_ids = []
    with torch.no_grad():
        for id,images,por_dat,sar_dat,sph_dat,c_sph_dat,l_sph_dat,class_id,A_crat_dat,A_lrat_dat in test_loader:

            id = id.to(device)
            images = Variable(images, requires_grad=True)

            por_dat = por_dat.view(por_dat.size(0),1)
            sar_dat = sar_dat.view(por_dat.size(0),1)
            sph_dat = sph_dat.view(por_dat.size(0),1)
            c_sph_dat = c_sph_dat.view(c_sph_dat.size(0),1)
            l_sph_dat = l_sph_dat.view(l_sph_dat.size(0),1)
            A_crat_dat = A_crat_dat.view(A_crat_dat.size(0),1)
            A_lrat_dat = A_lrat_dat.view(A_lrat_dat.size(0),1)
            class_id = class_id.view(class_id.size(0),1)

            images = images.to(device)
            por_dat = por_dat.to(device)
            sar_dat = sar_dat.to(device)
            sph_dat = sph_dat.to(device)
            c_sph_dat = c_sph_dat.to(device)
            l_sph_dat = l_sph_dat.to(device)
            A_crat_dat = A_crat_dat.to(device)
            A_lrat_dat = A_lrat_dat.to(device)
            class_id = class_id.to(device)

            #with autograd.detect_anomaly():
            # Forward pass
            geoms_out = model(images)
            sph = geoms_out[:,0]
            cross_sph = geoms_out[:,1]
            length_sph = geoms_out[:,2]
            porosity = geoms_out[:,3]
            sar = geoms_out[:,4]
            A_crat = geoms_out[:,5]
            A_lrat = geoms_out[:,6]

            porosity = porosity.view(porosity.size(0),1)
            sar = sar.view(sar.size(0),1)
            sph = sph.view(sph.size(0),1)
            cross_sph = cross_sph.view(cross_sph.size(0),1)
            length_sph = length_sph.view(length_sph.size(0),1)
            A_crat = A_crat.view(A_crat.size(0),1)
            A_lrat = A_lrat.view(A_lrat.size(0),1)

            #### LOSS
            #PHYSICS GUIDED LOSS
            #loss_p_Cd,loss_pg1,loss_pg2,loss_pg3,loss_pg = physics_loss(Cd_inv,porosity,sar,sph,length_sph,cross_sph,A_char,A_cross,A_length,rho_air,fallspeed,V_CH)
            #loss_p_Cd = loss_p_Cd.to(device)
            #loss_pg = loss_pg.to(device)
            loss_pg = 0.0
            # regular geom loss, porosity and c_sph known in mascdb
            loss_por = criterion(porosity,por_dat)
            loss_sar = criterion(sar,sar_dat)
            loss_sph = criterion(sph,sph_dat)
            loss_c_sph = criterion(cross_sph,c_sph_dat)
            loss_l_sph = criterion(length_sph,l_sph_dat)
            loss_Acrat = criterion(A_crat,A_crat_dat)
            loss_Alrat = criterion(A_lrat,A_lrat_dat)
            loss_Geom = loss_por + loss_sar + loss_sph + loss_c_sph + loss_l_sph + loss_Acrat + loss_Alrat

            test_Geom_loss += loss_Geom * images.size(0)
            test_p_Geom_loss += loss_pg * images.size(0)

            sph_out.append(sph.item())
            c_sph_out.append(cross_sph.item())
            l_sph_out.append(length_sph.item())
            por_out.append(porosity.item())
            sar_out.append(sar.item())
            Acrat_out.append(A_crat.item())
            Alrat_out.append(A_lrat.item())

            por_ans.append(por_dat.item())
            sar_ans.append(sar_dat.item())
            sph_ans.append(sph_dat.item())
            c_sph_ans.append(c_sph_dat.item())
            l_sph_ans.append(l_sph_dat.item())
            Acrat_ans.append(A_crat_dat.item())
            Alrat_ans.append(A_lrat_dat.item())

            test_ids.append(id.item())
            #err += (abs(outputs - t_Cd)/t_Cd)*100
            #print(err)
            #loss =  l + loss_mean
    #outs_act = np.array(outs)
    #answers_act = np.array(answers)
    por_out = np.array(por_out)
    sar_out = np.array(sar_out)
    sph_out = np.array(sph_out)
    l_sph_out = np.array(l_sph_out)
    c_sph_out = np.array(c_sph_out)
    Acrat_out = np.array(Acrat_out)
    Alrat_out = np.array(Alrat_out)
    por_out = por_out
    sar_out = sar_out*0.5
    sph_out = sph_out
    l_sph_out = l_sph_out*2
    c_sph_out = c_sph_out*3.5
    Acrat_out = Acrat_out*1.5
    Alrat_out = Alrat_out*1.5

    por_ans = np.array(por_ans)
    sar_ans = np.array(sar_ans)
    sph_ans = np.array(sph_ans)
    l_sph_ans = np.array(l_sph_ans)
    c_sph_ans = np.array(c_sph_ans)
    Acrat_ans = np.array(Acrat_ans)
    Alrat_ans = np.array(Alrat_ans)
    por_ans = por_ans
    sar_ans = sar_ans*0.5
    sph_ans = sph_ans
    l_sph_ans = l_sph_ans*2
    c_sph_ans = c_sph_ans*3.5
    Acrat_ans = Acrat_ans*1.5
    Alrat_ans = Alrat_ans*1.5

    test_ids = np.array(test_ids)
    
if correlation == 2:
    model.eval()  # Set the model in evaluation mode
    criterion = nn.MSELoss() #MSELoss()  # Mean squared error loss #L1Loss()
    outs = []
    por_out = []
    sar_out = []
    c_sph_out = []
    l_sph_out = []
    sph_out = []
    Acrat_out = []
    Alrat_out = []
    Df_out = []

    answers = []
    por_ans = []
    sar_ans = []
    c_sph_ans = []
    l_sph_ans = []
    sph_ans = []
    Acrat_ans = []
    Alrat_ans = []
    Df_ans = []

    test_Geom_loss = 0.0
    test_p_Geom_loss = 0.0

    test_ids = []
    with torch.no_grad():
        for id,images,por_dat,sar_dat,sph_dat,c_sph_dat,l_sph_dat,class_id,A_crat_dat,A_lrat_dat,Df_dat in test_loader:

            id = id.to(device)
            images = Variable(images, requires_grad=True)

            por_dat = por_dat.view(por_dat.size(0),1)
            sar_dat = sar_dat.view(por_dat.size(0),1)
            sph_dat = sph_dat.view(por_dat.size(0),1)
            c_sph_dat = c_sph_dat.view(c_sph_dat.size(0),1)
            l_sph_dat = l_sph_dat.view(l_sph_dat.size(0),1)
            A_crat_dat = A_crat_dat.view(A_crat_dat.size(0),1)
            A_lrat_dat = A_lrat_dat.view(A_lrat_dat.size(0),1)
            class_id = class_id.view(class_id.size(0),1)
            Df_dat = Df_dat.view(Df_dat.size(0),1)

            images = images.to(device)
            por_dat = por_dat.to(device)
            sar_dat = sar_dat.to(device)
            sph_dat = sph_dat.to(device)
            c_sph_dat = c_sph_dat.to(device)
            l_sph_dat = l_sph_dat.to(device)
            A_crat_dat = A_crat_dat.to(device)
            A_lrat_dat = A_lrat_dat.to(device)
            class_id = class_id.to(device)
            Df_dat = Df_dat.to(device)

            #with autograd.detect_anomaly():
            # Forward pass
            geoms_out = model(images)
            sph = geoms_out[:,0]
            cross_sph = geoms_out[:,1]
            length_sph = geoms_out[:,2]
            porosity = geoms_out[:,3]
            sar = geoms_out[:,4]
            A_crat = geoms_out[:,5]
            A_lrat = geoms_out[:,6]
            Df = geoms_out[:,7]

            porosity = porosity.view(porosity.size(0),1)
            sar = sar.view(sar.size(0),1)
            sph = sph.view(sph.size(0),1)
            cross_sph = cross_sph.view(cross_sph.size(0),1)
            length_sph = length_sph.view(length_sph.size(0),1)
            A_crat = A_crat.view(A_crat.size(0),1)
            A_lrat = A_lrat.view(A_lrat.size(0),1)
            Df = Df.view(Df.size(0),1)

            #### LOSS
            #PHYSICS GUIDED LOSS

            loss_pg = 0.0
            # regular geom loss, porosity and c_sph known in mascdb
            loss_por = criterion(porosity,por_dat)
            loss_sar = criterion(sar,sar_dat)
            loss_sph = criterion(sph,sph_dat)
            loss_c_sph = criterion(cross_sph,c_sph_dat)
            loss_l_sph = criterion(length_sph,l_sph_dat)
            loss_Acrat = criterion(A_crat,A_crat_dat)
            loss_Alrat = criterion(A_lrat,A_lrat_dat)
            loss_Df = criterion(Df,Df_dat)
            loss_Geom = loss_por + loss_sar + loss_sph + loss_c_sph + loss_l_sph + loss_Acrat + loss_Alrat + loss_Df

            test_Geom_loss += loss_Geom * images.size(0)
            test_p_Geom_loss += loss_pg * images.size(0)

            sph_out.append(sph.item())
            c_sph_out.append(cross_sph.item())
            l_sph_out.append(length_sph.item())
            por_out.append(porosity.item())
            sar_out.append(sar.item())
            Acrat_out.append(A_crat.item())
            Alrat_out.append(A_lrat.item())
            Df_out.append(Df.item())

            por_ans.append(por_dat.item())
            sar_ans.append(sar_dat.item())
            sph_ans.append(sph_dat.item())
            c_sph_ans.append(c_sph_dat.item())
            l_sph_ans.append(l_sph_dat.item())
            Acrat_ans.append(A_crat_dat.item())
            Alrat_ans.append(A_lrat_dat.item())
            Df_ans.append(Df_dat.item())

            test_ids.append(id.item())
            #err += (abs(outputs - t_Cd)/t_Cd)*100
            #print(err)
            #loss =  l + loss_mean
    #outs_act = np.array(outs)
    #answers_act = np.array(answers)
    por_out = np.array(por_out)
    sar_out = np.array(sar_out)
    sph_out = np.array(sph_out)
    l_sph_out = np.array(l_sph_out)
    c_sph_out = np.array(c_sph_out)
    Acrat_out = np.array(Acrat_out)
    Alrat_out = np.array(Alrat_out)
    Df_out = np.array(Df_out)
    por_out = por_out
    sar_out = sar_out*0.5
    sph_out = sph_out
    l_sph_out = l_sph_out*2
    c_sph_out = c_sph_out*3.5
    Acrat_out = Acrat_out*1.5
    Alrat_out = Alrat_out*1.5
    Df_out = Df_out*3

    por_ans = np.array(por_ans)
    sar_ans = np.array(sar_ans)
    sph_ans = np.array(sph_ans)
    l_sph_ans = np.array(l_sph_ans)
    c_sph_ans = np.array(c_sph_ans)
    Acrat_ans = np.array(Acrat_ans)
    Alrat_ans = np.array(Alrat_ans)
    Df_ans = np.array(Df_ans)
    por_ans = por_ans
    sar_ans = sar_ans*0.5
    sph_ans = sph_ans
    l_sph_ans = l_sph_ans*2
    c_sph_ans = c_sph_ans*3.5
    Acrat_ans = Acrat_ans*1.5
    Alrat_ans = Alrat_ans*1.5
    Df_ans = Df_ans*3

    test_ids = np.array(test_ids)
por_ans[por_ans==0] = 0.01
por_nrmse = np.sqrt(np.sum((por_out - por_ans)**2)/len((por_out - por_ans)**2))/np.mean(por_ans)*100
sar_nrmse = np.sqrt(np.sum((sar_out - sar_ans)**2)/len((sar_out - sar_ans)**2))/np.mean(sar_ans)*100
sph_nrmse = np.sqrt(np.sum((sph_out - sph_ans)**2)/len((sph_out - sph_ans)**2))/np.mean(sph_ans)*100
csph_nrmse = np.sqrt(np.sum((c_sph_out - c_sph_ans)**2)/len((c_sph_out - c_sph_ans)**2))/np.mean(c_sph_ans)*100
lsph_nrmse = np.sqrt(np.sum((l_sph_out - l_sph_ans)**2)/len((l_sph_out - l_sph_ans)**2))/np.mean(l_sph_ans)*100
Acrat_nrmse = np.sqrt(np.sum((Acrat_out - Acrat_ans)**2)/len((Acrat_out - Acrat_ans)**2))/np.mean(Acrat_ans)*100 
Alrat_nrmse = np.sqrt(np.sum((Alrat_out - Alrat_ans)**2)/len((Alrat_out - Alrat_ans)**2))/np.mean(Alrat_ans)*100
print('NRMSE scores:')
print('Porosity:',por_nrmse,'%')
print('Surface area ratio:',sar_nrmse,'%')
print('Sphericity:',sph_nrmse,'%')
print('Crosswise Sphericity:',csph_nrmse,'%')
print('Lengthwise Sphericity:',lsph_nrmse,'%')
print('Crosswise area ratio:',Acrat_nrmse,'%')
print('Lengthwise area ratio:',Alrat_nrmse,'%')
if correlation == 2:
    Df_nrmse = np.sqrt(np.sum((Df_out - Df_ans)**2)/len((Df_out - Df_ans)**2))/np.mean(Df_ans)*100
    print('Fractal dimension:',Df_nrmse,'%')

Train set size: 196
Test set size: 49
Epoch 1/600 - Train geom Loss: 1.255643 - Test Loss: 1.269276
Epoch 1/600 - Train phys geom Loss: 0.649049 - Test Loss: 0.639289
Epoch 1/600 - Train por loss: 0.1148 - Test loss: 0.1738
Epoch 1/600 - Train sar loss: 0.0426 - Test loss: 0.0399
Epoch 1/600 - Train sph loss: 0.2343 - Test loss: 0.2176
Epoch 1/600 - Train c_sph loss: 0.3127 - Test loss: 0.2980
Epoch 1/600 - Train l_sph loss: 0.3185 - Test loss: 0.3083
Epoch 1/600 - Train Acrat loss: 0.1072 - Test loss: 0.1105
Epoch 1/600 - Train Alrat loss: 0.1255 - Test loss: 0.1210


Epoch 2/600 - Train geom Loss: 1.242366 - Test Loss: 1.256449
Epoch 2/600 - Train phys geom Loss: 0.625209 - Test Loss: 0.616338
Epoch 2/600 - Train por loss: 0.1148 - Test loss: 0.1731
Epoch 2/600 - Train sar loss: 0.0423 - Test loss: 0.0394
Epoch 2/600 - Train sph loss: 0.2317 - Test loss: 0.2152
Epoch 2/600 - Train c_sph loss: 0.3078 - Test loss: 0.2933
Epoch 2/600 - Train l_sph loss: 0.3166 - Test loss: 0.3068
Epoch 

KeyboardInterrupt: 