In [5]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tone_embeddings_dataset import ToneContoursDataset
from tone_embeddings_model import ToneFeedForward
from id_word_map import IdWordMap
from splits import splits
# Map between file ids and words; its `ids` collection defines how many
# embedding files the rest of the notebook iterates over.
wordMap = IdWordMap()
n_files = len(wordMap.ids)
print(n_files)

# Settings:
epochs = 40                  # number of passes over the training data
currentSplit = 0             # index into `splits`: which held-out id list is used for testing
wordTone = False             # True -> read inputs from 'v_tone_embeddings_word', else 'v_tone_embeddings'
outputsForConfusion = False  # True -> also print/plot the true vs. predicted test words
seed = 0                     # fixed RNG seed so repeated runs are comparable

# Weight initialisation and DataLoader shuffling draw from the global torch
# RNG; seed it (and numpy's) once, before the model and loader are created.
# NOTE: some GPU kernels may still be non-deterministic.
torch.manual_seed(seed)
np.random.seed(seed)

718


In [6]:
# Dataset for the currently selected split and tone-embedding variant.
# NOTE(review): presumably excludes the ids in splits[currentSplit] from
# training -- confirm in tone_embeddings_dataset.
dataset = ToneContoursDataset(
    n_files,
    split=currentSplit,
    word=wordTone,
)

# Mini-batches of 10 samples, drawn in shuffled order.
dataloader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=True,
)

In [7]:
# Pick the compute device first so everything below is created against it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and move it to its final device/dtype *before* the
# optimizer is constructed, so the optimizer is handed parameters that already
# live where they will be trained.
model = ToneFeedForward()
model.to(device, dtype=torch.float32)

# Regression objective: mean squared error between the model output and the
# target vector supplied by the dataset.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)

# Last expression of the cell: display the architecture.
model

ToneFeedForward(
  (net): Sequential(
    (0): Linear(in_features=50, out_features=512, bias=True)
    (1): Sigmoid()
    (2): Linear(in_features=512, out_features=2048, bias=True)
    (3): Sigmoid()
    (4): Linear(in_features=2048, out_features=3072, bias=True)
  )
)

In [8]:
# Train for `epochs` passes over the dataloader, reporting the per-sample
# average loss after every pass.
for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for inputs, targets in dataloader:
        # Batches come off the loader on the CPU; move them next to the model.
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        # loss_fn returns the batch mean; weight by batch size so that the
        # final division by len(dataset) gives a true per-sample average even
        # when the last batch is smaller.
        running_loss += batch_loss.item() * inputs.size(0)

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(dataset):.4f}")

Epoch 1/40, Loss: 0.0244
Epoch 2/40, Loss: 0.0005
Epoch 3/40, Loss: 0.0004
Epoch 4/40, Loss: 0.0004
Epoch 5/40, Loss: 0.0004
Epoch 6/40, Loss: 0.0004
Epoch 7/40, Loss: 0.0004
Epoch 8/40, Loss: 0.0004
Epoch 9/40, Loss: 0.0004
Epoch 10/40, Loss: 0.0004
Epoch 11/40, Loss: 0.0004
Epoch 12/40, Loss: 0.0004
Epoch 13/40, Loss: 0.0004
Epoch 14/40, Loss: 0.0004
Epoch 15/40, Loss: 0.0004
Epoch 16/40, Loss: 0.0004
Epoch 17/40, Loss: 0.0004
Epoch 18/40, Loss: 0.0004
Epoch 19/40, Loss: 0.0004
Epoch 20/40, Loss: 0.0004
Epoch 21/40, Loss: 0.0004
Epoch 22/40, Loss: 0.0004
Epoch 23/40, Loss: 0.0004
Epoch 24/40, Loss: 0.0004
Epoch 25/40, Loss: 0.0004
Epoch 26/40, Loss: 0.0004
Epoch 27/40, Loss: 0.0004
Epoch 28/40, Loss: 0.0004
Epoch 29/40, Loss: 0.0004
Epoch 30/40, Loss: 0.0004
Epoch 31/40, Loss: 0.0004
Epoch 32/40, Loss: 0.0004
Epoch 33/40, Loss: 0.0004
Epoch 34/40, Loss: 0.0004
Epoch 35/40, Loss: 0.0004
Epoch 36/40, Loss: 0.0004
Epoch 37/40, Loss: 0.0004
Epoch 38/40, Loss: 0.0004
Epoch 39/40, Loss: 0.

In [9]:
# Load every context embedding (one .npy file per file id) and stack them
# along a new leading axis, so that index k of `meaning_stack` is the
# embedding stored in v_context_embeddings/k.npy. This acts as a small
# in-memory vector database for nearest-neighbour lookups.
tensors_to_stack = [
    torch.tensor(
        np.load(f'v_context_embeddings/{file_id}.npy').astype(np.float32),
        dtype=torch.float32,
    )
    for file_id in range(n_files)
]
meaning_stack = torch.stack(tensors_to_stack)


In [10]:
def predict_words(file_ids):
    """Return ``(true_words, predicted_words)`` for the given file ids.

    For every id the tone embedding stored at ``{directory}/{id}.npy`` is fed
    through ``model``. The predicted word is the ``wordMap`` entry whose index
    in ``meaning_stack`` has the smallest L2 distance to the model output.

    Args:
        file_ids: iterable of integer file ids to evaluate.

    Returns:
        Two equally long lists: the words looked up from the ids themselves,
        and the words looked up from the nearest ``meaning_stack`` index.
    """
    true_words, predicted_words = [], []
    with torch.no_grad():
        for file_id in file_ids:
            x_test = torch.tensor(np.load(f'{directory}/{file_id}.npy').astype(np.float32), dtype=torch.float32)
            # The model may have been moved to the GPU, while the input and
            # meaning_stack are created on the CPU: run the forward pass on
            # the model's device and bring the result back for the lookup.
            prediction = model(x_test.to(device)).cpu()
            # .item() turns the 0-d index tensor into a plain int, matching
            # how wordMap is indexed with file ids.
            closest = torch.norm(meaning_stack - prediction.unsqueeze(0), dim = 1).argmin().item()
            true_words.append(wordMap[file_id])
            predicted_words.append(wordMap[closest])
    return true_words, predicted_words


model.eval()
directory = 'v_tone_embeddings_word' if wordTone else 'v_tone_embeddings'

# train accuracy: every file id that is not in the held-out split
train_list = sorted(set(range(n_files)) - set(splits[currentSplit]))
train_true, train_predicted = predict_words(train_list)
train_successes = sum(t == p for t, p in zip(train_true, train_predicted))

print(f'Train Accuracy: {train_successes / len(train_list)}')

# test accuracy: the held-out ids of the current split
test_true, test_predicted = predict_words(splits[currentSplit])
test_successes = sum(t == p for t, p in zip(test_true, test_predicted))

# The confusion lists are only populated when explicitly requested.
confusionTrue, confusionPredicted = (test_true, test_predicted) if outputsForConfusion else ([], [])

print(f'Test Accuracy: {test_successes / len(splits[currentSplit])}')
if outputsForConfusion:
    print(confusionTrue, '\n', confusionPredicted)
    # Imported lazily so seaborn is only required when plotting is enabled.
    from seaborn import countplot
    # Pass the labels as `x=`: in recent seaborn releases the first positional
    # parameter of countplot is `data`, not the vector of categories.
    countplot(x=confusionPredicted)


Train Accuracy: 0.21428571428571427
Test Accuracy: 0.1875
