In [3]:
from google.colab import drive
drive.mount('/content/drive')
%cd '/content/drive/MyDrive/NLP'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/NLP


In [4]:
import torch
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import v_measure_score
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, Vocab
import random
from torch.nn.utils.rnn import pad_sequence
from torch.nn import TransformerEncoder, TransformerEncoderLayer

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
data = pd.read_csv("bibledata.csv")

In [7]:
means = torch.load("bible_similar_mean.pt", map_location = device)
embeddings = torch.load("bible_embeddings.pt", map_location = device)

styles = embeddings - means
similarity_matrix = embeddings.matmul(embeddings.T)
top_n = similarity_matrix.sort()[1][:,-7:-1]

In [8]:
bib_tokenizer = get_tokenizer('basic_english')

In [9]:
def bib_tokens(data):
  for text in iter(data):
    yield bib_tokenizer(text)

In [10]:
bib_vocab = build_vocab_from_iterator(bib_tokens(data.text), specials=["<unk>", "<pad>", "<sos>", "<eos>"])

In [11]:
def data_process(text):
  index = 0
  raw_bib_iter = iter(text.text)
  data = []
  for i in raw_bib_iter:
    bib_tensor_ = torch.tensor([bib_vocab[token] for token in bib_tokenizer(i)],
                            dtype=torch.long)
    data.append((bib_tensor_, index))
    index+=1
  return data

In [12]:
table = data_process(data)

In [13]:
max_len = 0
for i in table:
  if len(i) > max_len:
    max_len = len(i)

In [28]:
BATCH_SIZE = 73
PAD_IDX = bib_vocab['<pad>']
SOS_IDX = bib_vocab['<sos>']
EOS_IDX = bib_vocab['<eos>']

random.seed(42)

def generate_batch(data_batch):
  batch, idx = [], []
  for item, i in data_batch:
    batch.append(torch.cat([torch.tensor([SOS_IDX]), item, torch.tensor([EOS_IDX])], dim=0))
    idx.append(i)
  batch = pad_sequence(batch, padding_value=PAD_IDX, batch_first = True)
  return batch, idx

train_iter = DataLoader(table, batch_size=BATCH_SIZE, shuffle=None, collate_fn=generate_batch)

In [16]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [50]:
class TransformerModel(nn.Module):

    def __init__(self, ntoken, d_model, nhead, d_hid,
                 nlayers, dropout= 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first = True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        """
        Arguments:
            src: Tensor, shape ``[seq_len, batch_size]``
            src_mask: Tensor, shape ``[seq_len, seq_len]``

        Returns:
            output Tensor of shape ``[seq_len, batch_size, ntoken]``
        """
        src = self.encoder(src) * np.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src)
        return output.transpose(1, 0)[0]

In [51]:
class Model(nn.Module):
  def __init__(self, out_dim, n_tokens, transformer_dim, n_head, n_layers, dropout = 0.5, d_hid = 1024):
    super().__init__()
    self.transformer = TransformerModel(n_tokens, transformer_dim, n_head, d_hid, n_layers, dropout)
    self.decoder = nn.Linear(transformer_dim + out_dim, out_dim)
  
  def forward(self, x, centre):
    encoded = self.transformer(x)
    out = self.decoder(torch.concatenate([encoded, centre], dim = 1))
    return out

In [52]:
model = Model(384, len(bib_vocab), 64, 16, 6).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [53]:
epochs = 20

In [55]:
for epoch in tqdm(range(epochs)):
    overall_loss = 0
    for batch_idx, (x, index) in enumerate(train_iter):
        x = x.view(BATCH_SIZE, -1).to(device)
        #nearest = embeddings[top_n[index]]
        centre = means[index]

        optimizer.zero_grad()
        predict = model(x, centre)
        #nears = model(nearest)

        loss = torch.sum(torch.square(1 - F.cosine_similarity(predict, embeddings[index]))) #+ torch.sum(torch.square(F.cosine_similarity(predict.unsqueeze(1), nears, dim=2)))/6
      
        overall_loss += loss.item()
        
        loss.backward()
        optimizer.step()
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / (batch_idx*BATCH_SIZE))

  5%|▌         | 1/20 [00:05<01:38,  5.20s/it]

	Epoch 1 complete! 	Average Loss:  0.229379263396647


 10%|█         | 2/20 [00:10<01:31,  5.07s/it]

	Epoch 2 complete! 	Average Loss:  0.06536986765661972


 10%|█         | 2/20 [00:14<02:12,  7.35s/it]


KeyboardInterrupt: ignored

In [45]:
zs = np.zeros((len(data), 64))
for batch_idx, (x, i) in enumerate(train_iter):
  zs[i] = (model.transformer(x.to(device)).transpose(1, 0)[0].cpu().detach())
km = KMeans(n_clusters = 7)
km.fit(zs)
data["labels"] = km.labels_



In [46]:
v_measure_score(data.id, data.labels)

0.0012181990890682024