# Imports and Mounting

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
import pickle
import os
import random
import shutil
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
import torchvision.models as models
from torchvision.models import ResNet18_Weights
import torch.optim as optim

from PIL import Image
from torch.utils.data import Dataset

!pip install tqdm
from tqdm import tqdm

!pip install rdkit
from rdkit import Chem
from rdkit.Chem.rdMolDescriptors import CalcMolFormula
from rdkit.Chem import Fragments

from collections import defaultdict
import re





# Sampling the data

In [3]:
project_root = "/content/drive/Shareddrives/CIS5190FinalProj"

handdrawn_root = "/content/drive/Shareddrives/CIS5190FinalProj/DECIMER_HDM_Dataset_Images"
computer_root = "/content/drive/Shareddrives/CIS5190FinalProj/Img2Mol"

In [4]:
print(f"There are {len(os.listdir(handdrawn_root))} handdrawn images")
print(f"There are {len(os.listdir(computer_root))} computer generated images")

There are 5088 handdrawn images
There are 10880 computer generated images


In [5]:
## Training : Validation : Testing = 7 : 2 : 1
NUMBER_OF_TRAINING_IMGS = 7000
NUMBER_OF_VAL_IMGS = NUMBER_OF_TRAINING_IMGS // 7 * 2
NUMBER_OF_TESTING_IMGS = NUMBER_OF_TRAINING_IMGS // 7

train_root = project_root + "/Train0.5"
val_root = project_root + "/Val0.5"
test_root = project_root + "/Test0.5"

## Creating image sets

In [6]:
def create_new_directory(directory):
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)

In [7]:
## Randomly sample the images
def split_data(comp_gen_percentage=0.5):
    if comp_gen_percentage < 0.5:
        print("Computer generated percentage cannot be smaller than 0.5")
        return

    # Computer generated images
    num_comp_gen_imgs_train = int(comp_gen_percentage * NUMBER_OF_TRAINING_IMGS)
    num_comp_gen_imgs_val = int(comp_gen_percentage * NUMBER_OF_VAL_IMGS)
    num_comp_gen_imgs_test = int(comp_gen_percentage * NUMBER_OF_TESTING_IMGS)


    files = os.listdir(computer_root) # a list of names of computer-generated images
    random.shuffle(files)

    comp_gen_filenames_train = files[:num_comp_gen_imgs_train]
    comp_gen_filenames_val = files[num_comp_gen_imgs_train: num_comp_gen_imgs_train+num_comp_gen_imgs_val]
    comp_gen_filenames_test = files[num_comp_gen_imgs_train+num_comp_gen_imgs_val:]

    # Create the directories
    create_new_directory(train_root)
    create_new_directory(val_root)
    create_new_directory(test_root)

    # Copy the images into the directories
    for filename in comp_gen_filenames_train:
        shutil.copy(os.path.join(computer_root, filename), train_root)
    for filename in comp_gen_filenames_val:
        shutil.copy(os.path.join(computer_root, filename), val_root)
    for filename in comp_gen_filenames_test:
        shutil.copy(os.path.join(computer_root, filename), test_root)



    # Handdrawn images
    num_hand_imgs_train = int((1 - comp_gen_percentage) * NUMBER_OF_TRAINING_IMGS)
    num_hand_imgs_val = int((1 - comp_gen_percentage) * NUMBER_OF_VAL_IMGS)
    num_hand_imgs_test = int((1 - comp_gen_percentage) * NUMBER_OF_TESTING_IMGS)


    files = os.listdir(handdrawn_root)
    random.shuffle(files)

    hand_filenames_train = files[:num_hand_imgs_train]
    hand_filenames_val = files[num_hand_imgs_train: num_hand_imgs_train+num_hand_imgs_val]
    hand_filenames_test = files[num_hand_imgs_train+num_hand_imgs_val:]

    # Copy the images into the directories
    for filename in hand_filenames_train:
        shutil.copy(os.path.join(handdrawn_root, filename), train_root)
    for filename in hand_filenames_val:
        shutil.copy(os.path.join(handdrawn_root, filename), val_root)
    for filename in hand_filenames_test:
        shutil.copy(os.path.join(handdrawn_root, filename), test_root)

In [None]:
## Only need to run this once
#split_data(1)

## Loading the data into a dataloader

### Get the labels

In [11]:
functional_group_functions = [
    Fragments.fr_Al_COO,
    Fragments.fr_Al_OH,
    Fragments.fr_Al_OH_noTert,
    Fragments.fr_ArN,
    Fragments.fr_Ar_COO,
    Fragments.fr_Ar_N,
    Fragments.fr_Ar_NH,
    Fragments.fr_Ar_OH,
    Fragments.fr_COO,
    Fragments.fr_COO2,
    Fragments.fr_C_O,
    Fragments.fr_C_O_noCOO,
    Fragments.fr_C_S,
    Fragments.fr_HOCCN,
    Fragments.fr_Imine,
    Fragments.fr_NH0,
    Fragments.fr_NH1,
    Fragments.fr_NH2,
    Fragments.fr_N_O,
    Fragments.fr_Ndealkylation1,
    Fragments.fr_Ndealkylation2,
    Fragments.fr_Nhpyrrole,
    Fragments.fr_SH,
    Fragments.fr_aldehyde,
    Fragments.fr_alkyl_carbamate,
    Fragments.fr_alkyl_halide,
    Fragments.fr_allylic_oxid,
    Fragments.fr_amide,
    Fragments.fr_amidine,
    Fragments.fr_aniline,
    Fragments.fr_aryl_methyl,
    Fragments.fr_azide,
    Fragments.fr_azo,
    Fragments.fr_barbitur,
    Fragments.fr_benzene,
    Fragments.fr_benzodiazepine,
    Fragments.fr_bicyclic,
    Fragments.fr_diazo,
    Fragments.fr_dihydropyridine,
    Fragments.fr_epoxide,
    Fragments.fr_ester,
    Fragments.fr_ether,
    Fragments.fr_furan,
    Fragments.fr_guanido,
    Fragments.fr_halogen,
    Fragments.fr_hdrzine,
    Fragments.fr_hdrzone,
    Fragments.fr_imidazole,
    Fragments.fr_imide,
    Fragments.fr_isocyan,
    Fragments.fr_isothiocyan,
    Fragments.fr_ketone,
    Fragments.fr_ketone_Topliss,
    Fragments.fr_lactam,
    Fragments.fr_lactone,
    Fragments.fr_methoxy,
    Fragments.fr_morpholine,
    Fragments.fr_nitrile,
    Fragments.fr_nitro,
    Fragments.fr_nitro_arom,
    Fragments.fr_nitro_arom_nonortho,
    Fragments.fr_nitroso,
    Fragments.fr_oxazole,
    Fragments.fr_oxime,
    Fragments.fr_para_hydroxylation,
    Fragments.fr_phenol,
    Fragments.fr_phenol_noOrthoHbond,
    Fragments.fr_phos_acid,
    Fragments.fr_phos_ester,
    Fragments.fr_piperdine,
    Fragments.fr_piperzine,
    Fragments.fr_priamide,
    Fragments.fr_prisulfonamd,
    Fragments.fr_pyridine,
    Fragments.fr_quatN,
    Fragments.fr_sulfide,
    Fragments.fr_sulfonamd,
    Fragments.fr_sulfone,
    Fragments.fr_term_acetylene,
    Fragments.fr_tetrazole,
    Fragments.fr_thiazole,
    Fragments.fr_thiocyan,
    Fragments.fr_thiophene,
    Fragments.fr_unbrch_alkane,
    Fragments.fr_urea,
]

In [12]:
def smiles_to_bitarray(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        raise ValueError(f"Invalid SMILES string: {smiles}")
    bit_array = np.zeros(len(functional_group_functions), dtype=int)
    for i, func in enumerate(functional_group_functions):
        bit_array[i] = 1 if func(mol) > 0 else 0
    return bit_array

In [13]:
def get_number_of_functional_groups():
    return len(functional_group_functions)

In [14]:
handdrawn_df = pd.read_csv(project_root + "/DECIMER_HDM_Dataset_SMILES.tsv", sep='\t')
handdrawn_smiles_dict = handdrawn_df.set_index('IDs')['SMILES'].to_dict()

# Process each SMILES to convert into func group arrays
handdrawn_functional_group_arrays = {key: (smiles_to_bitarray(value)) for key, value in handdrawn_smiles_dict.items()}

# To verify the transformation, print the first 5 elements of the transformed dictionary
for key in list(handdrawn_functional_group_arrays.keys())[:5]:
    print(key, ":", handdrawn_functional_group_arrays[key])

CDK_Depict_1_2 : [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]
CDK_Depict_1_4 : [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]
CDK_Depict_1_5 : [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 1 0 0 0 0 0 0 0 1 0]
CDK_Depict_1_6 : [0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1
 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]
CDK_Depict_1_7 : [0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]


In [16]:
with open("/content/drive/Shareddrives/CIS5190FinalProj/Img2Mol_map.pkl", 'rb') as f:
    comp_gen_df = pickle.load(f)
comp_gen_smiles_dict = comp_gen_df.set_index('Image')['SMILES'].to_dict()
comp_gen_functional_group_arrays = {key: smiles_to_bitarray(value) for key, value in comp_gen_smiles_dict.items()}

for key in list(comp_gen_functional_group_arrays.keys())[:5]:
    print(key, ":", comp_gen_functional_group_arrays[key])

0.png : [0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 0 1
 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 1 0 0 0 0 0 1 0 0]
1.png : [0 0 0 0 0 1 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 1 0 0
 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1
 0 0 1 0 0 0 0 0 0 0 0]
2.png : [0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0
 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]
3.png : [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 0 0
 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0
 0 0 1 0 0 0 0 0 0 0 0]
4.png : [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 0 0
 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 1 0 0 0 0 0 0 0 0]


In [17]:
# Verify transformation for both datasets by printing the first 5 elements
print("Hand-drawn dataset first 5 atom arrays:")
for key in list(handdrawn_functional_group_arrays.keys())[:5]:
    print(key, ":", handdrawn_functional_group_arrays[key])

print("\nComputer-generated dataset first 5 atom arrays:")
for key in list(comp_gen_functional_group_arrays.keys())[:5]:
    print(key, ":", comp_gen_functional_group_arrays[key])


Hand-drawn dataset first 5 atom arrays:
CDK_Depict_1_2 : [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]
CDK_Depict_1_4 : [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]
CDK_Depict_1_5 : [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 1 0 0 0 0 0 0 0 1 0]
CDK_Depict_1_6 : [0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1
 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]
CDK_Depict_1_7 : [0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0]

Comp

In [20]:
## Combine the two dictionaries together
combined_functional_group_mapping = handdrawn_functional_group_arrays.copy()

for key, value in comp_gen_functional_group_arrays.items():
    # Remove .png from the key
    key = key.replace('.png', '')
    combined_functional_group_mapping[key] = value

for key, value in comp_gen_functional_group_arrays.items():
    # Remove .png from the key
    key = key.replace('.png', '')
    combined_functional_group_mapping[key] = value

In [21]:
## Returns the label based on the filename of the image

def get_binary_label(filename):
    return combined_functional_group_mapping.get(filename, None)

In [22]:
def validate_dataset(image_names, root_dir, label_dict):
    valid_image_names = []
    for img_name in image_names:
        img_path = os.path.join(root_dir, img_name)
        if os.path.isfile(img_path) and img_name.replace(".png", "") in label_dict:
            valid_image_names.append(img_name)
    return valid_image_names


In [23]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, label_dict, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_names = validate_dataset(os.listdir(root_dir), root_dir, label_dict)
        self.label_dict = label_dict

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        img_path = os.path.join(self.root_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        label = self.label_dict[img_name.replace(".png", "")]

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(label, dtype=torch.float32)
        return image, label


In [24]:
## Transformation
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [36]:
train_functional_group_dataset = CustomDataset(root_dir=train_root, label_dict=combined_functional_group_mapping, transform=transform)
val_functional_group_dataset = CustomDataset(root_dir=val_root, label_dict=combined_functional_group_mapping, transform=transform)
test_functional_group_dataset = CustomDataset(root_dir=test_root, label_dict=combined_functional_group_mapping, transform=transform)

train_functional_group_loader = DataLoader(train_functional_group_dataset, batch_size=32, shuffle=True, num_workers=0)
val_functional_group_loader = DataLoader(val_functional_group_dataset, batch_size=32, shuffle=False, num_workers=0)
test_functional_group_loader = DataLoader(test_functional_group_dataset, batch_size=32, shuffle=False, num_workers=0)

Single point for sanity check

In [58]:
train_functional_group_dataset = train_functional_group_loader.dataset

single_item_index = 0
single_item_functional_group_dataset = Subset(train_functional_group_dataset, [single_item_index])
single_item_functional_group_loader = DataLoader(single_item_functional_group_dataset, batch_size=1, shuffle=False)
for single_image, single_label in single_item_functional_group_dataset:
    print("Image shape:", single_image.shape)
    print("Label:", single_label)



Image shape: torch.Size([3, 256, 256])
Label: tensor([0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


# Data Augmentation

In [None]:
aug1 = transforms.Compose([
    transforms.RandomHorizontalFlip(),   # Random horizontal flip
    transforms.RandomRotation(degrees=15) # Random rotation by up to 15 degrees
])

aug2 = transforms.Compose([
    transforms.RandomResizedCrop(size=224),  # Random resized crop to 224x224
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Color jitter
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10)  # Random affine transformations
])

aug3 = transforms.Compose([
    transforms.RandomHorizontalFlip(),   # Random horizontal flip
    transforms.RandomRotation(degrees=15), # Random rotation by up to 15 degrees
    transforms.RandomResizedCrop(size=224),  # Random resized crop to 224x224
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Color jitter
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10)  # Random affine transformations
])

# Model Pipeline (Transfer Learning)

In [59]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [60]:
def calculate_accuracy(preds, labels):
    # Ensure both predictions and labels are binary (0 or 1)
    preds = preds.round().int()
    labels = labels.int()

    # Count matches where both prediction and label are 1
    correct_ones = (preds & labels).sum().item()

    # Count total cases where either prediction or label is 1
    total_ones = ((preds == 1) | (labels == 1)).sum().item()

    # Calculate accuracy based on matches and total relevant cases
    accuracy = correct_ones / total_ones if total_ones > 0 else 0.0
    return accuracy


In [61]:
def calculate_batch_accuracy(outputs, labels):
    probs = torch.sigmoid(outputs)
    preds = (probs >= 0.5).float()

    # Convert tensors to integer type for bitwise operations
    preds_int = preds.int()
    labels_int = labels.int()

    # Correct matches where both prediction and label are 1
    correct_ones = (preds_int & labels_int).sum().item()

    # Total relevant cases where either prediction or label is 1
    total_ones = ((preds_int == 1) | (labels_int == 1)).sum().item()

    # Calculate accuracy based on matches and total relevant cases
    accuracy = correct_ones / total_ones if total_ones > 0 else 0.0
    return accuracy


In [62]:
def train_single_point(model, data_loader, criterion, optimizer, num_epochs=20):
    model.train()  # Set model to training mode
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        correct_train = 0.0
        total_train = 0

        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)  # Move data to the correct device

            optimizer.zero_grad()  # Zero the gradients
            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()  # Backpropagation
            optimizer.step()  # Update weights

            epoch_loss += loss.item()

            # Calculate batch accuracy
            batch_correct_train = calculate_batch_accuracy(outputs, labels)
            correct_train += batch_correct_train * labels.size(0)
            total_train += labels.size(0)

        epoch_loss /= len(data_loader)
        epoch_accuracy = correct_train / total_train

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}')


In [63]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, log_interval=10):
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0.0
        total_train = 0

        print(f"Epoch [{epoch+1}/{num_epochs}]")

        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            batch_correct_train = calculate_batch_accuracy(outputs, labels)
            correct_train += batch_correct_train * labels.size(0)  # Accumulate correct predictions
            total_train += labels.size(0)  # Accumulate total predictions

            if batch_idx % log_interval == 0:
                current_train_loss = running_loss / (batch_idx + 1)
                current_train_accuracy = correct_train / total_train
                print(f"Train Batch [{batch_idx}/{len(train_loader)}], Loss: {current_train_loss:.4f}, Accuracy: {current_train_accuracy:.4f}")

        train_loss = running_loss / len(train_loader)
        train_accuracy = correct_train / total_train

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct_val = 0.0
        total_val = 0

        with torch.no_grad():
            for batch_idx, (inputs, labels) in enumerate(val_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                batch_correct_val = calculate_batch_accuracy(outputs, labels)
                correct_val += batch_correct_val * labels.size(0)  # Accumulate correct predictions
                total_val += labels.size(0)  # Accumulate total predictions

                if batch_idx % log_interval == 0:
                    current_val_loss = val_loss / (batch_idx + 1)
                    current_val_accuracy = correct_val / total_val
                    print(f"Val Batch [{batch_idx}/{len(val_loader)}], Loss: {current_val_loss:.4f}, Accuracy: {current_val_accuracy:.4f}")

        val_loss = val_loss / len(val_loader)
        val_accuracy = correct_val / total_val

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

        # Append values for this epoch to the lists
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

    return train_losses, train_accuracies, val_losses, val_accuracies


In [64]:
def test_model(model, test_loader, criterion):
    model.eval()
    test_loss = 0.0
    test_accuracy = 0.0
    total_batches = len(test_loader)

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(tqdm(test_loader, desc="Testing")):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # Calculate accuracy using the provided function
            batch_accuracy = calculate_batch_accuracy(outputs, labels)
            test_accuracy += batch_accuracy

            # Debugging: Print predictions and targets
            if i % 50 == 0:
                # Apply sigmoid and threshold to show rounded predictions
                probs = torch.sigmoid(outputs)
                preds = (probs >= 0.5).float()
                print(f'Batch {i}/{total_batches}, Current Test Loss: {test_loss/(i+1):.4f}, Current Test Accuracy: {test_accuracy/(i+1):.4f}')
                print("Predictions:", preds[:5].cpu().numpy())
                print("Targets:", labels[:5].cpu().numpy())

    test_loss /= total_batches
    test_accuracy /= total_batches

    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')

    return test_loss, test_accuracy

## ResNet

In [65]:
weights = ResNet18_Weights.DEFAULT

In [71]:
num_elements = 85

In [80]:
resnet_functional_group = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_functional_group.parameters():
    param.requires_grad = False

num_single_functional_group_features = resnet_functional_group.fc.in_features
resnet_functional_group.fc = nn.Linear(num_single_functional_group_features, num_elements)
resnet_functional_group = resnet_functional_group.to(device)  # Move model to GPU

In [81]:
# Load Single Point ResNet model
resnet_single_functional_group = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_single_functional_group.parameters():
    param.requires_grad = True

num_single_functional_group_features = resnet_single_functional_group.fc.in_features
resnet_single_functional_group.fc = nn.Linear(num_single_functional_group_features, num_elements)
resnet_single_functional_group = resnet_single_functional_group.to(device)  # Move model to resnet_single_functional_group

In [82]:
criterion = nn.BCEWithLogitsLoss()
optimizer_single = torch.optim.SGD(resnet_single_functional_group.parameters(), lr=0.001, momentum=0.9)
optimizer = torch.optim.SGD(resnet_functional_group.parameters(), lr=0.001, momentum=0.9)

In [89]:
train_single_point(resnet_single_functional_group, single_item_functional_group_loader, criterion, optimizer_single, num_epochs=100)
test_model(resnet_single_functional_group, single_item_functional_group_loader, criterion)

Epoch [1/100], Loss: 0.1782, Accuracy: 1.0000
Epoch [2/100], Loss: 0.1764, Accuracy: 1.0000
Epoch [3/100], Loss: 0.1747, Accuracy: 1.0000
Epoch [4/100], Loss: 0.1731, Accuracy: 1.0000
Epoch [5/100], Loss: 0.1714, Accuracy: 1.0000
Epoch [6/100], Loss: 0.1698, Accuracy: 1.0000
Epoch [7/100], Loss: 0.1683, Accuracy: 1.0000
Epoch [8/100], Loss: 0.1667, Accuracy: 1.0000
Epoch [9/100], Loss: 0.1652, Accuracy: 1.0000
Epoch [10/100], Loss: 0.1637, Accuracy: 1.0000
Epoch [11/100], Loss: 0.1622, Accuracy: 1.0000
Epoch [12/100], Loss: 0.1607, Accuracy: 1.0000
Epoch [13/100], Loss: 0.1593, Accuracy: 1.0000
Epoch [14/100], Loss: 0.1579, Accuracy: 1.0000
Epoch [15/100], Loss: 0.1565, Accuracy: 1.0000
Epoch [16/100], Loss: 0.1552, Accuracy: 1.0000
Epoch [17/100], Loss: 0.1538, Accuracy: 1.0000
Epoch [18/100], Loss: 0.1525, Accuracy: 1.0000
Epoch [19/100], Loss: 0.1512, Accuracy: 1.0000
Epoch [20/100], Loss: 0.1500, Accuracy: 1.0000
Epoch [21/100], Loss: 0.1487, Accuracy: 1.0000
Epoch [22/100], Loss: 

Testing: 100%|██████████| 1/1 [00:00<00:00, 50.72it/s]

Batch 0/1, Current Test Loss: 0.0898, Current Test Accuracy: 1.0000
Predictions: [[0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Targets: [[0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Test Loss: 0.0898, Test Accuracy: 1.0000





(0.0897650495171547, 1.0)

In [90]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(resnet_functional_group.parameters(), lr=0.01, momentum=0.9)

train_losses, train_accuracies, val_losses, val_accuracies = train_model(resnet_functional_group, train_functional_group_loader, val_functional_group_loader, criterion, optimizer, num_epochs=5, log_interval=10)

Epoch [1/5]
Train Batch [0/219], Loss: 0.1858, Accuracy: 0.2869
Train Batch [10/219], Loss: 0.1940, Accuracy: 0.2672
Train Batch [20/219], Loss: 0.1936, Accuracy: 0.2619
Train Batch [30/219], Loss: 0.1956, Accuracy: 0.2586
Train Batch [40/219], Loss: 0.1961, Accuracy: 0.2581
Train Batch [50/219], Loss: 0.1956, Accuracy: 0.2611
Train Batch [60/219], Loss: 0.1947, Accuracy: 0.2634
Train Batch [70/219], Loss: 0.1950, Accuracy: 0.2650
Train Batch [80/219], Loss: 0.1942, Accuracy: 0.2669
Train Batch [90/219], Loss: 0.1937, Accuracy: 0.2693
Train Batch [100/219], Loss: 0.1927, Accuracy: 0.2721
Train Batch [110/219], Loss: 0.1922, Accuracy: 0.2732
Train Batch [120/219], Loss: 0.1915, Accuracy: 0.2743
Train Batch [130/219], Loss: 0.1914, Accuracy: 0.2748
Train Batch [140/219], Loss: 0.1907, Accuracy: 0.2757
Train Batch [150/219], Loss: 0.1905, Accuracy: 0.2766
Train Batch [160/219], Loss: 0.1904, Accuracy: 0.2773
Train Batch [170/219], Loss: 0.1904, Accuracy: 0.2782
Train Batch [180/219], Loss

In [91]:
 test_model(resnet_functional_group, test_functional_group_loader, criterion)

Testing:   0%|          | 1/218 [00:00<00:59,  3.67it/s]

Batch 0/218, Current Test Loss: 0.1983, Current Test Accuracy: 0.4472
Predictions: [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0

Testing:  23%|██▎       | 51/218 [00:15<00:38,  4.32it/s]

Batch 50/218, Current Test Loss: 0.1843, Current Test Accuracy: 0.3133
Predictions: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 

Testing:  46%|████▋     | 101/218 [00:26<00:28,  4.06it/s]

Batch 100/218, Current Test Loss: 0.1968, Current Test Accuracy: 0.3658
Predictions: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.

Testing:  69%|██████▉   | 151/218 [00:37<00:15,  4.20it/s]

Batch 150/218, Current Test Loss: 0.2010, Current Test Accuracy: 0.3841
Predictions: [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.

Testing:  92%|█████████▏| 201/218 [00:48<00:03,  4.76it/s]

Batch 200/218, Current Test Loss: 0.2023, Current Test Accuracy: 0.3963
Predictions: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.

Testing: 100%|██████████| 218/218 [00:52<00:00,  4.16it/s]

Test Loss: 0.2029, Test Accuracy: 0.3978





(0.20289358482995165, 0.39779027675746276)