In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import json

In [11]:
train_data = []
train_inp_file = 'processed_data/GCDC/Clinton_train.jsonl'
with open(train_inp_file, 'r') as f:
    for line in f:
        json_obj = json.loads(line)
        # put <EOS> at the end of each sentence and add each sentence to the list
        for i in range(len(json_obj['sentences'])):
            json_obj['sentences'][i].append('<EOS>')
        # merge all sentences into one
        document = " ".join([word for sentence in json_obj['sentences'] for word in sentence])
        document = document.lower()
        label = json_obj['label']
        train_data.append((document, label))
   

test_data = []
test_inp_file = 'processed_data/GCDC/Clinton_test.jsonl'
with open(test_inp_file, 'r') as f:
    for line in f:
        json_obj = json.loads(line)
        # put <EOS> at the end of each sentence and add each sentence to the list
        for i in range(len(json_obj['sentences'])):
            json_obj['sentences'][i].append('<EOS>')
        # merge all sentences into one
        document = " ".join([word for sentence in json_obj['sentences'] for word in sentence])
        document = document.lower()
        label = json_obj['label']
        test_data.append((document, label))
        
print(len(train_data))
print(train_data[0])
print(train_data[1])
print(train_data[2])
print(len(test_data))
print(test_data[0])

# create vocabulary
vocab = set()
for document, label in train_data:
    for word in document.split():
        vocab.add(word.lower())
vocab = list(vocab)
vocab.append('<PAD>')
vocab.append('<UNK>')
print(len(vocab))
print(vocab[:10])


800
('two options the us views the transitional national council as the sole / only legitimate interlocutor of the libyan people during this interim period , as libyans come together to plan their own future and a permanent , inclusive constitutional system that protects the rights of all libyans . <eos> this is in contrast to the qadhafi regime , which has lost all legitimacy to rule . <eos> the us views the transitional national council as the legitimate interlocutor of the libyan people during this interim period , as libyans come together to plan their own future and a permanent , inclusive constitutional system that protects the rights of all libyans . <eos> this is in contrast to the qadhafi regime , which has lost all legitimacy to rule . <eos> the inc is the institution through which we are engaging the libyan people at this time . <eos>', 3)
("ambassador , we just received an email from the adoption service provider about these cases . <eos> i am currently reviewing the files 

In [12]:
# transform the documents into list of indices
def transform_doc(document, vocab):
    indices = []
    for word in document.split():
        if word.lower() in vocab:
            indices.append(vocab.index(word.lower()))
        else:
            indices.append(vocab.index('<UNK>'))
    return indices

# do it for all the sentences
train_data = [(transform_doc(document, vocab), label) for document, label in train_data]
test_data = [(transform_doc(document, vocab), label) for document, label in test_data]
# print(train_data[0])
# print(test_data[0])

# get the max length of the documents and also average length
total_len = 0
max_len = 0
count = 0
for document, label in train_data:
    max_len = max(max_len, len(document))
    total_len += len(document)
    count += 1

print(f'max doc length : {max_len}')
print(f'average doc length : {total_len/count}')

# pad the sentences to make them of same length
def pad_doc(document, max_len):
    if len(document) < max_len:
        document += [vocab.index('<PAD>')] * (max_len - len(document))
    return document

train_data = [(pad_doc(document, max_len), label) for document, label in train_data]
test_data = [(pad_doc(document, max_len), label) for document, label in test_data]
# print(train_data[0])
# print(len(train_data[0][0]))
# print(test_data[0])

# batchify the data after converting to tensors
class Batchify(Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return torch.tensor(self.data[idx][0]), torch.tensor(self.data[idx][1])
    
train_data = Batchify(train_data)
test_data = Batchify(test_data)

# create dataloaders
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=True)


max doc length : 415
average doc length : 194.9475


In [13]:
# every document has a label : 1 or 2 or 3
# we need a model that can take a document and predict the label and implement it using transformer (basically textual coherence)

# transformer model

class Transformer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_heads, n_layers, dropout, max_len):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embedding = nn.Embedding(max_len, embedding_dim)
        self.transformer = nn.Transformer(embedding_dim, n_heads, n_layers, dropout)
        self.fc = nn.Linear(embedding_dim, 3)
    
    def forward(self, x):
        # x : (batch_size, seq_len)
        batch_size = x.shape[0]
        seq_len = x.shape[1]
        x = self.embedding(x)
        # x : (batch_size, seq_len, embedding_dim)
        pos = torch.arange(0, seq_len).unsqueeze(0).repeat(batch_size, 1)
        # pos : (batch_size, seq_len)
        x = x + self.pos_embedding(pos)
        # x : (batch_size, seq_len, embedding_dim)
        x = x.permute(1, 0, 2)
        # x : (seq_len, batch_size, embedding_dim)
        x = self.transformer(x)
        # x : (seq_len, batch_size, embedding_dim)
        x = x.permute(1, 0, 2)
        # x : (batch_size, seq_len, embedding_dim)
        x = self.fc(x)
        # x : (batch_size, seq_len, 3)
        x = x[:, -1, :]
        # x : (batch_size, 3)
        return x
    
# create model
model = Transformer(len(vocab), 128, 8, 4, 0.2, max_len)
print(model)

TypeError: 'float' object cannot be interpreted as an integer