### Train model

In [10]:
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data_utils
from tqdm import tqdm

from utils.dataset import MyDataset
from utils.model import CoCa
from utils.tokenizer import Tokenizer
from utils.vit import SimpleViT

# Setup
device = 'cpu'
batch_size = 256
lr = 1e-4
num_epochs = 2
max_length = 23  # passed to MyDataset and to generation; presumably max tokens per SMILES — TODO confirm
contrastive_loss_weight = 0.0
caption_loss_weight = 1.0
DATA_DIR = '../data/'
# One results folder per contrastive weight, e.g. ./models_temp/contrastive0.0
RESULTS_DIR = os.path.join('./models_temp/', f'contrastive{contrastive_loss_weight}')

# Seed torch's global RNG so weight initialisation and dataloader shuffling are reproducible
SEED = 7
torch.manual_seed(SEED)

# Build the tokenizer from the vocabulary saved next to the data
tokenizer = Tokenizer().load_vocab(os.path.join(DATA_DIR, 'vocab.json'))

# One dataset per split; all three read the same split file and differ only by split name
train_dataset, val_dataset, test_dataset = (
    MyDataset(os.path.join(DATA_DIR, 'split.csv'), split_name, tokenizer, max_length)
    for split_name in ('train', 'val', 'test')
)

# Make model
# Spectrum encoder. num_classes=None presumably makes SimpleViT return features
# rather than class logits — TODO confirm against utils/vit.py.
vit = SimpleViT(
    channels = 1,
    seq_len = 2000,
    patch_size = 40,
    dim = 1024,
    depth = 6,
    heads = 8,
    dim_head = 64,
    mlp_dim = 1024,
    num_classes = None,
).to(device)

# CoCa on top of the encoder; image_dim is set to the ViT's dim (1024),
# presumably because the two must match — TODO confirm.
model = CoCa(
    img_encoder = vit,
    image_dim = 1024,
    dim = 512,
    num_tokens = len(tokenizer.vocab),
    unimodal_depth = 6,
    multimodal_depth = 9,
    heads = 8,
    dim_head = 64,
    contrastive_loss_weight = contrastive_loss_weight,
    caption_loss_weight = caption_loss_weight,
).to(device)

# Make dataloaders: same batch size for every split, and every split is shuffled
train_dataloader, val_dataloader, test_dataloader = (
    data_utils.DataLoader(split_dataset, batch_size=batch_size, shuffle=True)
    for split_dataset in (train_dataset, val_dataset, test_dataset)
)

# Training Loop
# Guarded on RESULTS_DIR: if a previous run already wrote results there, training is skipped.
if not os.path.exists(RESULTS_DIR):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Per-epoch means; the JSON files below are rewritten after every epoch
    all_train_losses = []
    all_train_accs = []
    all_val_losses = []
    all_val_accs = []
    for epoch in range(num_epochs):

        ## Train model
        model.train()
        train_losses, train_accs, train_outputs = [], [], []
        for spectrum, smiles in tqdm(train_dataloader):
            spectrum, smiles = spectrum.to(device), smiles.to(device)
            optimizer.zero_grad()
            loss, accuracy, outputs = model(tokenizer, text = smiles, images = spectrum, return_loss = True)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            train_accs.append(accuracy)
            train_outputs.extend(outputs)
        # Unweighted mean over batches (a smaller final batch counts the same as a full one)
        train_loss = sum(train_losses) / len(train_losses)
        train_acc = sum(train_accs) / len(train_accs)
        print(f'Epoch: {epoch}, Loss: {train_loss}, Accuracy: {train_acc}')

        ## Evaluate model
        model.eval()
        val_losses, val_accs, val_outputs = [], [], []
        with torch.no_grad():
            for spectrum, smiles in val_dataloader:
                spectrum, smiles = spectrum.to(device), smiles.to(device)
                loss, accuracy, output = model(tokenizer, text = smiles, images = spectrum, return_loss = True)
                val_losses.append(loss.item())
                val_accs.append(accuracy)
                val_outputs.extend(output)
        val_loss = sum(val_losses) / len(val_losses)
        val_acc = sum(val_accs) / len(val_accs)
        print(f'Epoch: {epoch}, Val Loss: {val_loss}, Val Accuracy: {val_acc}')

        # Save a checkpoint for every epoch from num_epochs // 2 onward (at least halfway done);
        # earlier epochs have no checkpoint on disk.
        if epoch >= num_epochs // 2:
            checkpoint_dir = os.path.join(RESULTS_DIR, 'checkpoints')
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'epoch_{epoch}.pt'))

        # Save outputs to JSON
        outputs_dir = os.path.join(RESULTS_DIR, 'outputs')
        os.makedirs(outputs_dir, exist_ok=True)
        with open(os.path.join(outputs_dir, f'train_outputs_epoch_{epoch}.json'), 'w') as f:
            json.dump(train_outputs, f)
        with open(os.path.join(outputs_dir, f'val_outputs_epoch_{epoch}.json'), 'w') as f:
            json.dump(val_outputs, f)

        # Save losses and accuracies
        all_train_losses.append(train_loss)
        all_train_accs.append(train_acc)
        all_val_losses.append(val_loss)
        all_val_accs.append(val_acc)

        with open(os.path.join(RESULTS_DIR, 'losses.json'), 'w') as f:
            json.dump({'train': all_train_losses, 'val': all_val_losses}, f)
        with open(os.path.join(RESULTS_DIR, 'accs.json'), 'w') as f:
            json.dump({'train': all_train_accs, 'val': all_val_accs}, f)

    # Plot per-epoch loss and accuracy curves side by side
    fig, ax = plt.subplots(1, 2, figsize=(15, 5))
    for axis, title, train_curve, val_curve in (
        (ax[0], 'Loss', all_train_losses, all_val_losses),
        (ax[1], 'Accuracy', all_train_accs, all_val_accs),
    ):
        axis.plot(train_curve)
        axis.plot(val_curve)
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.legend(['train', 'val'])

    # Save plots into this run's RESULTS_DIR, next to its losses/accs JSON.
    # Save before plt.show(): the inline backend may close open figures when showing.
    fig.savefig(os.path.join(RESULTS_DIR, 'history.png'))
    plt.show()
    plt.close(fig)

else:
    print('Model already trained')

Model already trained


### Test model (top-k accuracy of the best-validation checkpoint)

In [None]:
# Find the epoch with the highest val accuracy among epochs that have a checkpoint.
# Training only saves checkpoints for epoch >= num_epochs // 2, so the argmax over
# all epochs can point to a checkpoint file that was never written.
with open(os.path.join(RESULTS_DIR, 'accs.json'), 'r') as f:
    val_accs = json.load(f)['val']
checkpoint_dir = os.path.join(RESULTS_DIR, 'checkpoints')
saved_epochs = [
    epoch for epoch in range(len(val_accs))
    if os.path.exists(os.path.join(checkpoint_dir, f'epoch_{epoch}.pt'))
]
if not saved_epochs:
    raise FileNotFoundError(f'No epoch checkpoints found in {checkpoint_dir}')
# max() returns the first maximum on ties, matching list.index(max(...))
max_epoch = max(saved_epochs, key=lambda epoch: val_accs[epoch])

# Load max epoch model and switch to inference mode (disables train-only layer behaviour)
model.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'epoch_{max_epoch}.pt'), map_location=device))
model.eval()
print(f'Loaded epoch {max_epoch + 1} model with val accuracy {val_accs[max_epoch]}') # +1 because epoch 0 is the first epoch

# Set seeds so the (shuffled) test order and any sampling in generation are reproducible
np.random.seed(7)
torch.manual_seed(7)

# Evaluate model on test set: a sample counts as correct if any of its first k
# predictions equals the original sequence
test_accs = {}
with torch.no_grad():  # pure inference; don't build an autograd graph during generation
    for k in [1, 5, 10, 20]:
        # Generate outputs
        outputs = []
        for spectrum, smiles in tqdm(test_dataloader):
            spectrum, smiles = spectrum.to(device), smiles.to(device)
            output = model.generate_autoregressively(tokenizer, spectrum, smiles, max_length, num_attempts=k)
            outputs.extend(output)

        # Calculate accuracy (slicing tolerates fewer than k predictions instead of raising IndexError)
        num_correct = sum(
            any(pred == output['original'] for pred in output['predicted'][:k])
            for output in outputs
        )
        acc = num_correct / len(outputs)

        path_to_results = os.path.join(RESULTS_DIR, 'outputs', f'test_outputs_{k}_attempts.json')
        with open(path_to_results, 'w') as f:
            json.dump(outputs, f, indent=4)

        test_accs[f'{k}_attempts'] = acc

# Save test accuracies alongside the train/val curves
with open(os.path.join(RESULTS_DIR, 'accs.json'), 'r') as f:
    accs = json.load(f)
accs['test'] = test_accs
with open(os.path.join(RESULTS_DIR, 'accs.json'), 'w') as f:
    json.dump(accs, f, indent=4)

### Get test-set embeddings from each model's last-epoch checkpoint

In [1]:
import torch
import torch.utils.data as data_utils

import os
import sys
from utils.dataset import MyDataset
from utils.tokenizer import Tokenizer
from utils.vit import SimpleViT
from utils.model import CoCa
from tqdm import tqdm
import numpy as np

# Setup: same values as the training cell
DATA_DIR = '../data/'
device = 'cpu'
batch_size, max_length = 256, 23
lr, num_epochs = 1e-4, 2  # not used in this cell; kept for parity with the training cell
caption_loss_weight = 1.0
contrastive_loss_weight = 0.0  # reassigned by the loop over loss weights below

# The tokenizer is identical for every model, so build it once instead of per iteration
tokenizer = Tokenizer().load_vocab(os.path.join(DATA_DIR, 'vocab.json'))

# Every model loads checkpoints/epoch_<LAST_EPOCH>.pt.
# NOTE(review): 99 implies the runs in ./models/ were trained for 100 epochs, not num_epochs above — confirm.
LAST_EPOCH = 99

# Floats rather than strings, matching the type the training cell passes to CoCa;
# f'{0.1}' == '0.1', so the directory names are unchanged.
for contrastive_loss_weight in [0.0, 0.1, 0.5, 1.0]:
    RESULTS_DIR = os.path.join('./models/', f'contrastive{contrastive_loss_weight}')

    # Make datasets
    # All models must see the molecules in the same order so the saved embeddings line up
    # for later plots. A dedicated, freshly seeded generator fixes the shuffle order
    # regardless of how many global RNG draws model construction makes before iteration.
    torch.manual_seed(7)
    test_dataset = MyDataset(os.path.join(DATA_DIR, 'split.csv'), 'test', tokenizer, max_length)
    dataloader = data_utils.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=True,
        generator=torch.Generator().manual_seed(7),
    )

    # Load model (same architecture as in the training cell)
    vit = SimpleViT(
            seq_len = 2000,
            patch_size = 40,
            num_classes = None,
            dim = 1024,
            depth = 6,
            heads = 8,
            mlp_dim = 1024,
            channels = 1,
            dim_head = 64
        ).to(device)

    model = CoCa(
        dim = 512,
        img_encoder = vit,
        image_dim = 1024,
        num_tokens = len(tokenizer.vocab),
        unimodal_depth = 6,
        multimodal_depth = 9,
        dim_head = 64,
        heads = 8,
        caption_loss_weight = caption_loss_weight,
        contrastive_loss_weight = contrastive_loss_weight,
    ).to(device)

    model.load_state_dict(torch.load(os.path.join(RESULTS_DIR, 'checkpoints', f'epoch_{LAST_EPOCH}.pt'), map_location=device))
    model.eval()  # inference: disable train-only layer behaviour

    # Get embeddings (no autograd graph needed; the saved tensors carry no grad history)
    smiles_list = []
    all_smiles_embeddings = []
    all_spectrum_embeddings = []
    with torch.no_grad():
        for spectrum, smiles in tqdm(dataloader):
            spectrum, smiles = spectrum.to(device), smiles.to(device)
            smiles_embeddings, spectrum_embeddings = model(tokenizer, text = smiles, images = spectrum, return_embeddings = True)
            all_smiles_embeddings.append(smiles_embeddings)
            all_spectrum_embeddings.append(spectrum_embeddings)
            smiles_list.extend([tokenizer.decode(smile) for smile in smiles])

    # Concatenate per-batch embeddings along the batch dimension
    all_smiles_embeddings = torch.cat(all_smiles_embeddings, dim=0)
    all_spectrum_embeddings = torch.cat(all_spectrum_embeddings, dim=0)

    # Save embeddings
    embeddings_dir = os.path.join(RESULTS_DIR, 'embeddings')
    os.makedirs(embeddings_dir, exist_ok=True)
    torch.save(all_smiles_embeddings, os.path.join(embeddings_dir, 'smiles_embeddings.pt'))
    torch.save(all_spectrum_embeddings, os.path.join(embeddings_dir, 'spectrum_embeddings.pt'))

    # Save smiles list (one decoded SMILES per line, in the same order as the embeddings)
    with open(os.path.join(embeddings_dir, 'smiles.txt'), 'w') as f:
        f.write('\n'.join(smiles_list))

100%|██████████| 7/7 [00:02<00:00,  2.53it/s]
100%|██████████| 7/7 [00:02<00:00,  2.55it/s]
100%|██████████| 7/7 [00:02<00:00,  2.49it/s]
100%|██████████| 7/7 [00:02<00:00,  2.50it/s]
