In [None]:
# Figure 9: Attention visualizations for the CoCa model
import torch.utils.data as data_utils
import os
import sys

sys.path.append("../../")
from utils.dataset import MyDataset
from utils.tokenizer import Tokenizer
from utils.vit import SimpleViT
from utils.model import CoCa

# Notebook-wide settings consumed by the data and model code below
DATA_DIR = '../../data/'  # folder that holds vocab.json and split.csv
device = 'cpu'            # target passed to .to(device) when the models are built
max_length = 23           # forwarded to MyDataset; presumably the max token length (TODO confirm)
batch_size = 256          # batch size given to every DataLoader

# Make tokenizer
# NOTE(review): this relies on load_vocab() returning the tokenizer itself — confirm in utils.tokenizer.
tokenizer = Tokenizer().load_vocab(os.path.join(DATA_DIR, 'vocab.json'))

# Make datasets: one MyDataset per split, all read from the same CSV
split_csv = os.path.join(DATA_DIR, 'split.csv')
train_dataset = MyDataset(split_csv, 'train', tokenizer, max_length)
val_dataset = MyDataset(split_csv, 'val', tokenizer, max_length)
test_dataset = MyDataset(split_csv, 'test', tokenizer, max_length)

# Make dataloaders
# Only the training loader is shuffled. Validation and test loaders keep the
# dataset order so that evaluation batches are reproducible between runs.
train_dataloader = data_utils.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = data_utils.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = data_utils.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Build the spectrum encoder (SimpleViT) from an explicit keyword dictionary
vit_kwargs = dict(
    seq_len=2000,      # input length; with patch_size=40 this presumably yields 50 patches (TODO confirm)
    patch_size=40,
    num_classes=None,  # no class count is supplied to the encoder
    dim=1024,          # same value as the image_dim handed to CoCa later in the notebook
    depth=6,
    heads=8,
    mlp_dim=1024,
    channels=1,
    dim_head=64,
)
vit = SimpleViT(**vit_kwargs).to(device)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json
import os
import torch
from rdkit import Chem
from rdkit.Chem import Draw

# Test-set samples to visualise and the loss weights that identify the training run
indices = [1191, 1028, 368]
contrastive_loss_weight = 0.5
caption_loss_weight = 1.0

# The results folder name is derived from the contrastive loss weight
RESULTS_DIR = os.path.join('../models/', f'contrastive{contrastive_loss_weight}')

# Pick the epoch with the highest recorded validation accuracy
with open(os.path.join(RESULTS_DIR, 'accs.json'), 'r') as f:
    val_accs = json.load(f)['val']
if not val_accs:
    # max() on an empty list would also raise ValueError, but with an opaque message
    raise ValueError(f"No validation accuracies found in {os.path.join(RESULTS_DIR, 'accs.json')}")
max_epoch = val_accs.index(max(val_accs))  # position of the first maximum

model = CoCa(
    dim = 512,
    img_encoder = vit,
    image_dim = 1024,
    num_tokens = len(tokenizer.vocab),
    unimodal_depth = 6,
    multimodal_depth = 9,
    dim_head = 64,
    heads = 8,
    caption_loss_weight = caption_loss_weight,
    contrastive_loss_weight = contrastive_loss_weight,
).to(device)

# Load max epoch model
# NOTE(review): the file is looked up by the 0-based list position while the message
# below reports position + 1 — confirm that checkpoints are saved with 0-based names.
# torch.load unpickles the file, so only trusted checkpoints should be loaded here.
checkpoint_path = os.path.join(RESULTS_DIR, f'epoch_{max_epoch}.pt')
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Switch to inference mode so layers such as dropout behave deterministically
# when the attention maps are extracted.
model.eval()
print(f'Loaded epoch {max_epoch + 1} model with val accuracy {max(val_accs)}')

# Figure layout: one column per selected sample; rows are spectrum, attention map
# and molecule drawing.

# Tick positions and labels are the same for every spectrum panel, so compute them once.
# Each label is twice the sample position (0..2000 -> "0".."4000").
x_ticks = np.arange(0, 2001, 500)
x_labels = [f"{int(round(value * 2))}" for value in x_ticks]

# squeeze=False keeps `axes` two-dimensional even if `indices` holds a single entry
fig, axes = plt.subplots(3, len(indices), figsize=(14, 7), squeeze=False)
attention_images = []  # one AxesImage per column, used for the shared colour scale

for i, idx in enumerate(indices):
    spectrum, smiles = test_dataset[idx]

    # Add a batch dimension for processing
    spectrum = spectrum.unsqueeze(0)
    smiles = smiles.unsqueeze(0)

    # Decode the SMILES string
    decoded_smiles = tokenizer.decode(smiles[0])

    # Obtain model outputs including attention. Gradients are not needed for a
    # visualisation, so skip building the autograd graph.
    with torch.no_grad():
        _, _, _, query_attn, token_attn = model(
            tokenizer,
            text=smiles.to(device),
            images=spectrum.to(device),
            return_loss=True,
            return_attention=True,
            layer=8,
        )

    # Average over dim 1 (assumed to be the attention-head axis — TODO confirm)
    query_head = query_attn.mean(dim=1).squeeze()
    token_head = token_attn.mean(dim=1).squeeze()
    token_head = token_head[1:]  # Skip the first row (first token)

    # Compute combined attention scores
    head = torch.matmul(query_head, token_head)

    # Keep one row per character of the decoded SMILES and flip vertically.
    # NOTE(review): this assumes one attention row per character — confirm for
    # tokenizers with multi-character tokens.
    smiles_length = len(decoded_smiles)
    attention_map = np.flipud(head.detach().cpu().numpy()[:smiles_length, :])
    n_rows = attention_map.shape[0]  # may be smaller than smiles_length if `head` is short

    # Plot the spectrum with both axes inverted
    ax_spectrum = axes[0, i]
    spectrum_data = spectrum.squeeze().detach().cpu().numpy()
    ax_spectrum.plot(range(len(spectrum_data)), spectrum_data)
    ax_spectrum.invert_yaxis()
    ax_spectrum.invert_xaxis()
    ax_spectrum.set_xticks(x_ticks)
    ax_spectrum.set_xticklabels(x_labels)

    # Plot the attention map; labels are reversed to match the vertical flip above
    ax_attention = axes[1, i]
    im = ax_attention.imshow(attention_map, cmap='hot', aspect='equal')
    attention_images.append(im)
    ax_attention.invert_xaxis()
    ax_attention.set_yticks(np.arange(n_rows))
    ax_attention.set_yticklabels(list(decoded_smiles)[:n_rows][::-1], fontsize=10, rotation=90)

    # Display the SMILES as an image. MolFromSmiles returns None for strings it
    # cannot parse, so fall back to a text placeholder instead of crashing.
    ax_smiles = axes[2, i]
    mol = Chem.MolFromSmiles(decoded_smiles)
    if mol is not None:
        ax_smiles.imshow(Draw.MolToImage(mol, size=(200, 200)))
    else:
        ax_smiles.text(
            0.5, 0.5, f'Invalid SMILES:\n{decoded_smiles}',
            ha='center', va='center', transform=ax_smiles.transAxes,
        )
    ax_smiles.axis('off')

# Lay out the subplot grid first, leaving the right-hand 10% free. The colour bar
# axes is added afterwards because manually placed axes are not handled by tight_layout.
fig.tight_layout(rect=[0, 0, 0.9, 1])

if attention_images:
    # Put every attention panel on the same colour range so that the single
    # colour bar is valid for all of them, not only for the last one drawn.
    shared_min = min(float(img.get_array().min()) for img in attention_images)
    shared_max = max(float(img.get_array().max()) for img in attention_images)
    for img in attention_images:
        img.set_clim(shared_min, shared_max)

    # Create a color bar in a separate axes
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    fig.colorbar(attention_images[-1], cax=cbar_ax)

plt.show()