In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from transformers import BertModel
from datasets import load_dataset

In [None]:
import numpy as np
from scipy.io import loadmat

# Read subject 1's recording; the result is indexed by variable name below
# ('words', 'time', and later 'data').
harry_potter = loadmat('data/brain_data/subject_1.mat')

# Row 0 of the 'words' array holds one record per word; field 0 is read as the
# word text and field 1 as its time stamp.
word_entries = harry_potter['words'][0]

# 5176 word records are read (assumed to be the full length of the array -- TODO confirm).
words = [word_entries[i][0][0][0][0] for i in range(5176)]
word_times = [word_entries[i][1][0][0] for i in range(5176)]

# Column 0 of 'time' is taken as the time stamp of each of the 1351 fMRI rows
# (assumed to be the full length of the array -- TODO confirm).
tr_times = [harry_potter['time'][i, 0] for i in range(1351)]

# Row indices removed from both the fMRI matrix and the time stamps:
# 0-14, 335-354, 687-706, 966-985 and 1346-1350 (80 rows in total).
# NOTE(review): presumably the edges of each recording run -- confirm.
dont_include_indices = (
    list(range(15))
    + list(range(335, 355))
    + list(range(687, 707))
    + list(range(966, 986))
    + list(range(1346, 1351))
)

# Drop the excluded rows from the fMRI data (axis 0) ...
X_fmri = harry_potter['data']
useful_X_fmri = np.delete(X_fmri, dont_include_indices, axis=0)

# ... and the matching entries from the time stamps, so both stay aligned.
tr_times_arr = np.asarray(tr_times)
useful_tr_times = np.delete(tr_times_arr, dont_include_indices)

# For every kept time stamp, collect (in order) the words whose time falls in
# the closed window [stamp - 10, stamp]. One word list per kept fMRI row.
sentences = [
    [w for w, w_time in zip(words, word_times) if tr_stamp - 10 <= w_time <= tr_stamp]
    for tr_stamp in useful_tr_times
]

In [None]:
# Turn each word list into a single string. Every word is followed by one
# space, so a non-empty sentence ends with a trailing space and an empty word
# list yields ''.
actual_sentences = [
    ''.join(token + ' ' for token in word_list)
    for word_list in sentences
]

In [None]:
# Display the shape of the fMRI matrix after the excluded rows were removed.
useful_X_fmri.shape

In [None]:
class BrainBiasedBERT(nn.Module):
    """BERT encoder with a linear head that predicts an fMRI vector from text.

    The text is tokenized inside ``forward``, encoded with BERT, and the hidden
    state at token position 0 (the [CLS] token for BERT tokenizers) is mapped
    through a linear layer to ``n_voxels`` outputs.

    Args:
        n_voxels: Size of the predicted fMRI vector. Defaults to 37913, the
            value previously hard-coded here.
        model_name: Name or path of the pretrained tokenizer/encoder to load.
            Defaults to 'bert-base-cased'.
    """

    def __init__(self, n_voxels=37913, model_name='bert-base-cased'):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name)
        # Take the encoder width from its config rather than hard-coding 768,
        # so a different checkpoint still gets a correctly sized head.
        self.linear = nn.Linear(self.bert.config.hidden_size, n_voxels)

    def forward(self, x):
        """Predict fMRI vectors for raw text.

        Args:
            x: A string or a list of strings.

        Returns:
            Tensor of shape (batch, n_voxels).
        """
        # Pad to the longest item in the batch; truncate anything longer than
        # the encoder's maximum length instead of failing inside BERT.
        encoded = self.tokenizer(x, return_tensors='pt', padding=True, truncation=True)
        # The tokenizer always produces CPU tensors; move them to wherever the
        # model's parameters live so the module also works after `.to(device)`.
        encoded = encoded.to(self.linear.weight.device)
        representations = self.bert(**encoded).last_hidden_state
        cls_representation = representations[:, 0, :]
        return self.linear(cls_representation)

In [None]:
# Build the model; the constructor loads the pretrained 'bert-base-cased'
# tokenizer and encoder.
model = BrainBiasedBERT()

In [None]:
# Smoke test: forward pass on the first five sentences, then show the output shape.
pred_fmri = model(actual_sentences[:5])
pred_fmri.shape

In [None]:
# Wrap the filtered fMRI array as a tensor and take its first five rows, to
# pair with the five-sentence forward pass; then show the slice's shape.
fmri = torch.as_tensor(useful_X_fmri)
truth_fmri = fmri[:5,:]
truth_fmri.shape

In [None]:
# Mean-squared error between predicted and recorded fMRI vectors.
loss_function = nn.MSELoss()

In [None]:
# Sanity check: loss on the five-example batch before any training.
# NOTE(review): truth_fmri is not cast with .float() here, unlike the targets
# in the training loop -- confirm the two dtypes are compatible.
loss_function(pred_fmri, truth_fmri)

In [None]:
from torch.utils.data import DataLoader

# One (sentence, fMRI row) pair per kept time point.
dataset = [(actual_sentences[i], fmri[i, :]) for i in range(1271)]

# NOTE: contiguous 80/20 split with no shuffling. As flagged by the original
# author, the train and validation portions overlap in words and in brain
# state, so validation numbers from this split are optimistic.
n_rows = len(dataset)
split_at = int(.8 * n_rows)
train_dataset, val_dataset = dataset[:split_at], dataset[split_at:]

# Single-example batches, served in their original order.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)
test_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [None]:
import matplotlib.pyplot as plt

# Fine-tune the model on the training loader and plot the running loss curve
# after every epoch.
full_num_epochs = 15
loss_over_time = []  # training loss recorded at every optimisation step
time = []            # global step index of each recorded loss (x-axis of the plot)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
steps_per_epoch = len(train_dataloader)
for epoch in range(1, full_num_epochs + 1):
    for batch_idx, (data, targets) in enumerate(train_dataloader):
        # `data` holds the batch's sentences; data[0] is the first one, so this
        # loop relies on the loader yielding single-example batches.
        preds = model(data[0])
        # Cast the targets to float32 to match the model output's dtype.
        loss = loss_function(preds, targets.float())
        loss_over_time.append(loss.item())
        # Global step across epochs. The previous `batch_idx * epoch` restarted
        # at 0 every epoch and stretched the axis, so the plotted curve folded
        # back over itself.
        time.append((epoch - 1) * steps_per_epoch + batch_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Show the loss curve accumulated so far.
    plt.plot(time, loss_over_time)
    plt.show()

In [None]:
#model.eval()
#with torch.no_grad():
#    test_losses = []
#    for x, y in test_dataloader:
#        preds = model(x[0])
#        test_loss = loss_function(preds,y.float())
#        test_losses.append(test_loss)
        
#print(torch.mean(torch.as_tensor(test_losses))) 
#model.train()

In [None]:
# Save only the model's parameters (its state_dict, not the module object) to
# the relative path 'state_dict'.
torch.save(model.state_dict(), 'state_dict')

# Loading evaluation suite

In [None]:
# Load the 'multi_nli' dataset with `datasets.load_dataset`
# (presumably used for the evaluation that follows -- confirm).
mnli = load_dataset('multi_nli')