In [1]:
from google.colab import drive
drive.mount('/content/drive')
%cd '/content/drive/MyDrive/NLP'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/NLP


In [2]:
import torch
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import v_measure_score
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

In [3]:
embeddings = torch.load("bible_embeddings.pt")
means = torch.load("bible_similar_mean.pt")

In [4]:
data = pd.read_csv("bibledata.csv")

In [5]:
styles = embeddings - means
target = F.normalize(styles)

In [6]:
l1 = 128
l2 = 64
l3 = 32
z_dim = 16

In [21]:
class L2Norm(nn.Module):
    def __init__(self, dim=1, eps=1e-12):
        super(L2Norm, self).__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        self.dim = len(x.shape) - 1
        return F.normalize(x, p=2, dim=self.dim, eps=self.eps)

In [22]:
def Encoder(x_dim, z_dim):
  model = nn.Sequential(
    nn.Linear(x_dim, l1),
    nn.ReLU(),
    nn.Linear(l1, l2),
    nn.ReLU(),
    nn.Linear(l2, l3),
    nn.ReLU(),
    nn.Linear(l3, z_dim))
  return model

In [23]:
def Decoder(x_dim, z_dim):
  model = nn.Sequential(
  nn.Linear(z_dim, l3),
    nn.ReLU(),
    nn.Linear(l3, l2),
    nn.ReLU(),
    nn.Linear(l2, l1),
    nn.ReLU(),
    nn.Linear(l1, x_dim),
    L2Norm())
  return model

In [24]:
class Model(nn.Module):
  def __init__(self, x_dim, z_dim):
      super(Model, self).__init__()
      self.Encoder = Encoder(x_dim, z_dim)
      self.Decoder = Decoder(x_dim, z_dim)
  
  def forward(self, x):
      z = self.Encoder(x)
      out = self.Decoder(z)
        
      return out

In [25]:
epochs = 30
batch_size = 64

In [12]:
dataset = TensorDataset(embeddings, target)
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last = True)

In [26]:
model = Model(384, z_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [27]:
for epoch in tqdm(range(epochs)):
    overall_loss = 0
    for batch_idx, (x, label) in enumerate(train_dataloader):
        x = x.view(batch_size, 384)

        optimizer.zero_grad()

        predict = model(x)
        loss = torch.sum(torch.square(1 - F.cosine_similarity(predict, label)))
      
        overall_loss += loss.item()
        
        loss.backward()
        optimizer.step()
    validation_loss = torch.sum(torch.square(1 - F.cosine_similarity(model(embeddings[-43:]), target[-43:]))) 
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / (batch_idx*64), "\tValidation Loss:", validation_loss / 43)

  3%|▎         | 1/30 [00:01<00:29,  1.02s/it]

	Epoch 1 complete! 	Average Loss:  0.8754091155098145 	Validation Loss: tensor(0.7545, grad_fn=<DivBackward0>)


  7%|▋         | 2/30 [00:02<00:28,  1.01s/it]

	Epoch 2 complete! 	Average Loss:  0.7150739521147257 	Validation Loss: tensor(0.6909, grad_fn=<DivBackward0>)


 10%|█         | 3/30 [00:03<00:27,  1.03s/it]

	Epoch 3 complete! 	Average Loss:  0.6577289941081081 	Validation Loss: tensor(0.6443, grad_fn=<DivBackward0>)


 13%|█▎        | 4/30 [00:04<00:26,  1.01s/it]

	Epoch 4 complete! 	Average Loss:  0.6360573104347091 	Validation Loss: tensor(0.6267, grad_fn=<DivBackward0>)


 17%|█▋        | 5/30 [00:05<00:25,  1.04s/it]

	Epoch 5 complete! 	Average Loss:  0.616175964295146 	Validation Loss: tensor(0.5983, grad_fn=<DivBackward0>)


 20%|██        | 6/30 [00:06<00:26,  1.11s/it]

	Epoch 6 complete! 	Average Loss:  0.5965381347271334 	Validation Loss: tensor(0.5790, grad_fn=<DivBackward0>)


 23%|██▎       | 7/30 [00:07<00:26,  1.15s/it]

	Epoch 7 complete! 	Average Loss:  0.579375434711755 	Validation Loss: tensor(0.5481, grad_fn=<DivBackward0>)


 27%|██▋       | 8/30 [00:08<00:25,  1.16s/it]

	Epoch 8 complete! 	Average Loss:  0.5632308936980833 	Validation Loss: tensor(0.5466, grad_fn=<DivBackward0>)


 30%|███       | 9/30 [00:10<00:25,  1.20s/it]

	Epoch 9 complete! 	Average Loss:  0.548909062362579 	Validation Loss: tensor(0.5366, grad_fn=<DivBackward0>)


 33%|███▎      | 10/30 [00:11<00:27,  1.35s/it]

	Epoch 10 complete! 	Average Loss:  0.5357357464640974 	Validation Loss: tensor(0.5170, grad_fn=<DivBackward0>)


 37%|███▋      | 11/30 [00:13<00:27,  1.45s/it]

	Epoch 11 complete! 	Average Loss:  0.5200248184692429 	Validation Loss: tensor(0.4869, grad_fn=<DivBackward0>)


 40%|████      | 12/30 [00:14<00:26,  1.47s/it]

	Epoch 12 complete! 	Average Loss:  0.5064069632306156 	Validation Loss: tensor(0.4815, grad_fn=<DivBackward0>)


 43%|████▎     | 13/30 [00:16<00:23,  1.39s/it]

	Epoch 13 complete! 	Average Loss:  0.4950457059116249 	Validation Loss: tensor(0.4664, grad_fn=<DivBackward0>)


 47%|████▋     | 14/30 [00:17<00:21,  1.32s/it]

	Epoch 14 complete! 	Average Loss:  0.4860211169145193 	Validation Loss: tensor(0.4572, grad_fn=<DivBackward0>)


 50%|█████     | 15/30 [00:18<00:19,  1.27s/it]

	Epoch 15 complete! 	Average Loss:  0.47687109748283063 	Validation Loss: tensor(0.4497, grad_fn=<DivBackward0>)


 53%|█████▎    | 16/30 [00:19<00:17,  1.23s/it]

	Epoch 16 complete! 	Average Loss:  0.47011219234351653 	Validation Loss: tensor(0.4365, grad_fn=<DivBackward0>)


 57%|█████▋    | 17/30 [00:20<00:15,  1.22s/it]

	Epoch 17 complete! 	Average Loss:  0.46140274872262793 	Validation Loss: tensor(0.4211, grad_fn=<DivBackward0>)


 60%|██████    | 18/30 [00:22<00:14,  1.23s/it]

	Epoch 18 complete! 	Average Loss:  0.45556845973773175 	Validation Loss: tensor(0.4135, grad_fn=<DivBackward0>)


 63%|██████▎   | 19/30 [00:23<00:13,  1.22s/it]

	Epoch 19 complete! 	Average Loss:  0.4491975329008447 	Validation Loss: tensor(0.4077, grad_fn=<DivBackward0>)


 67%|██████▋   | 20/30 [00:24<00:11,  1.20s/it]

	Epoch 20 complete! 	Average Loss:  0.44253339196544095 	Validation Loss: tensor(0.4043, grad_fn=<DivBackward0>)


 70%|███████   | 21/30 [00:26<00:12,  1.35s/it]

	Epoch 21 complete! 	Average Loss:  0.4373745848256421 	Validation Loss: tensor(0.4089, grad_fn=<DivBackward0>)


 73%|███████▎  | 22/30 [00:27<00:11,  1.44s/it]

	Epoch 22 complete! 	Average Loss:  0.43149939251233294 	Validation Loss: tensor(0.4060, grad_fn=<DivBackward0>)


 77%|███████▋  | 23/30 [00:29<00:10,  1.46s/it]

	Epoch 23 complete! 	Average Loss:  0.4269728692899267 	Validation Loss: tensor(0.3952, grad_fn=<DivBackward0>)


 80%|████████  | 24/30 [00:30<00:08,  1.39s/it]

	Epoch 24 complete! 	Average Loss:  0.4216896239892546 	Validation Loss: tensor(0.3861, grad_fn=<DivBackward0>)


 83%|████████▎ | 25/30 [00:31<00:06,  1.32s/it]

	Epoch 25 complete! 	Average Loss:  0.4175376818481698 	Validation Loss: tensor(0.3804, grad_fn=<DivBackward0>)


 87%|████████▋ | 26/30 [00:32<00:05,  1.27s/it]

	Epoch 26 complete! 	Average Loss:  0.4110986285180931 	Validation Loss: tensor(0.3836, grad_fn=<DivBackward0>)


 90%|█████████ | 27/30 [00:34<00:03,  1.25s/it]

	Epoch 27 complete! 	Average Loss:  0.40590632159307777 	Validation Loss: tensor(0.3777, grad_fn=<DivBackward0>)


 93%|█████████▎| 28/30 [00:35<00:02,  1.22s/it]

	Epoch 28 complete! 	Average Loss:  0.40123374957636176 	Validation Loss: tensor(0.3716, grad_fn=<DivBackward0>)


 97%|█████████▋| 29/30 [00:36<00:01,  1.23s/it]

	Epoch 29 complete! 	Average Loss:  0.3969596814678376 	Validation Loss: tensor(0.3678, grad_fn=<DivBackward0>)


100%|██████████| 30/30 [00:37<00:00,  1.25s/it]

	Epoch 30 complete! 	Average Loss:  0.39259055167077533 	Validation Loss: tensor(0.3678, grad_fn=<DivBackward0>)





In [29]:
x0 = embeddings[10730]
y0 = target[10730]
yhat = model(x0)
torch.norm(yhat)

tensor(1.0000, grad_fn=<LinalgVectorNormBackward0>)

In [30]:
y0.dot(yhat)

tensor(0.4326, grad_fn=<DotBackward0>)

In [31]:
zs = torch.zeros((10731, z_dim))
for i in tqdm(range(len(embeddings))):
  zs[i] = model.Encoder(embeddings[i])

100%|██████████| 10731/10731 [00:03<00:00, 2741.98it/s]


In [32]:
km = KMeans(n_clusters = 7)
km.fit(zs.detach())
data["labels"] = km.labels_



In [33]:
v_measure_score(data.id, data.labels)

0.30002697504003606