In [11]:
import json
import torch
from nmr2struct.spec2frags_model import Spec2FragsDataset, Transformer
from tqdm import tqdm

In [12]:
torch.cuda.is_available()

True

In [13]:
data = json.load(open('nmr2struct/data/processed_spec_data.json', 'r'))

In [14]:
train_val_idxs = json.load(open('nmr2struct/train_test_split.json'))

In [15]:
dataset = Spec2FragsDataset(data=data)

100%|█████████████████████████████████████████████████████████| 1332709/1332709 [04:57<00:00, 4483.17it/s]


In [16]:
specs_vocab_size = 300

SPECS_BOS_TOKEN = specs_vocab_size
SPECS_EOS_TOKEN = specs_vocab_size + 1
SPECS_PAD_TOKEN = specs_vocab_size + 2

# Correct
frags_vocab_size = 5113
FRAGS_BOS_TOKEN = frags_vocab_size
FRAGS_EOS_TOKEN = frags_vocab_size + 1
FRAGS_PAD_TOKEN = frags_vocab_size + 2

In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [18]:
model = Transformer(
    src_vocab_size=specs_vocab_size + 3,
    tgt_vocab_size=frags_vocab_size + 3,
    d_model=256,
    num_heads=8,
    num_layers=4,
    d_ff=256,
    max_seq_length=100,
    dropout=0.2
)

checkpoint = 'nmr2struct_up_propper_fixed_frags_vocab_10'

model.load_state_dict(torch.load(f'nmr2struct/checkpoints/{checkpoint}.pt', map_location = device))

  model.load_state_dict(torch.load(f'nmr2struct/checkpoints/{checkpoint}.pt', map_location = device))


<All keys matched successfully>

In [19]:
@torch.inference_mode
def generate(
    src : torch.Tensor
) -> torch.Tensor:
    tokens = [FRAGS_BOS_TOKEN]
    while len(tokens) < 100 and tokens[-1] != FRAGS_EOS_TOKEN:
        tokens.append(
            model(
                src=src.unsqueeze(0).to(device),
                tgt=torch.tensor([tokens], dtype=torch.int64).to(device)
            )[0, -1, :].argmax().item()
        )
    return torch.tensor(tokens[1:-1], dtype=torch.int64)

In [20]:
frags_comparison = {'id':[], 'generated' : [], 'actual' : []}
for i in tqdm(train_val_idxs['val_idxs']):
    frags_comparison['id'].append(i)
    frags_comparison['generated'].append(generate(dataset[i][0]))
    frags_comparison['actual'].append(dataset[i][1][1:-1])

100%|███████████████████████████████████████████████████████████| 266542/266542 [4:07:56<00:00, 17.92it/s]


In [21]:
import pickle

In [22]:
with open(f'nmr2struct/data/{checkpoint}_frags_comparison.pkl', 'wb') as handle:
    pickle.dump(frags_comparison, handle, protocol=pickle.HIGHEST_PROTOCOL)