In [1]:
# Dataload

import pandas as pd

df = pd.read_csv("./../_datasets/twitter_za/RSA_tweet_data.csv")
df = df["tweet_text"]
for i in range(len(df)):
    tweet = df.iloc[i][2:len(df.iloc[i])-1]    
    df.iloc[i] = tweet

print(f"Dataset Length: {len(df)}")
df.head()

  df = pd.read_csv("./../_datasets/twitter_za/RSA_tweet_data.csv")


Dataset Length: 67585


0    RT @TygressAndy: Her killer\xe2\x80\x99s famil...
1    my misandry doesn't go unjustified. \n#menaret...
2    RT @zozitunzi: My little sister's friend, a be...
3    RT @MatlhagaKebo: \xe2\x80\x9cWhy don\xe2\x80\...
4    RT @ElihleGwala: My heart bleeds for Kwasa\xe2...
Name: tweet_text, dtype: object

In [3]:
# Dataset

import torch
from collections import Counter
import re

def cleanText(t):
    t = re.sub(r'http\S+', '', t)
    t = re.sub(r"\[A-Za-z]*\.com", " ", t)
    return t.replace(".", " . ").replace(",", " , ").replace(";", " ; ").replace("?", " ? ").replace("!", " ! ").replace("(", " (").replace(")", ") ").replace("...", " ... ").replace("\"", " \"")

class Dataset(torch.utils.data.Dataset):
    def __init__(self, sequence_length=100):
        self.sequence_length = sequence_length
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()
        
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        
        self.words_indexes = [self.word_to_index[w] for w in self.words]
        
        print(f"Vocab Size: {len(self.uniq_words)}\n")
        
    def load_words(self):
        text = df.str.cat(sep=" ").lower()
        text = cleanText(text)
        text = text.split(" ")
        text = [x.encode("ascii", "ignore") for x in text]
        text = [x.decode() for x in text if not "\\x" in x.decode()]
        text = [x for x in text if not x == ""]
        text = [x for x in text if not "#" in x and not "@" in x and not "http" in x and "\\n" not in x and not x == "rt"]
        return text
    
    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)
    
    def __len__(self):
        return len(self.words_indexes) - self.sequence_length
    
    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )
    
dataset = Dataset(sequence_length=100)

Vocab Size: 17741



In [4]:
# Model

from torch import nn

class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 5
        
        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2
        )
        self.fc = nn.Linear(self.lstm_size, self.lstm_size*2)
        self.fc2 = nn.Linear(self.lstm_size*2, self.lstm_size)
        self.fc3 = nn.Linear(self.lstm_size, n_vocab)
        
    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        output = self.fc(output)
        output = self.fc2(output)
        logits = self.fc3(output)
        return logits, state
    
    def init_state(self, sequence_length):
        return (
            torch.zeros(self.num_layers, sequence_length, self.lstm_size),
            torch.zeros(self.num_layers, sequence_length, self.lstm_size),
        )

In [10]:
# Train

import numpy as np
from torch import optim
from torch.utils.data import DataLoader
import wandb

# Main

# Best Settings so Far
# Model:
# 5 LSTM Layers, embed_size and lstm_size of 128
# 3 Linear FC Layers lstm_size => *2 => *1 => n__vocab
# 
# Hparams: 
# SEQ_LEN 100
# Batch 64 - increasing doesn't really help
# Adam LR 1e-3

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 25
SEQ_LEN = 100
BATCH_SIZE = 64
LOAD_MODEL=False
LOAD_CHECKPOINT="twitter_lstm_10e.pth"

dataset = Dataset(sequence_length=SEQ_LEN)
model = Model(dataset)
model.to(DEVICE)

def train(dataset, model, num_epochs=25, sequence_length=100, batch_size=64, device="cpu"):
    print("Training")
    wandb.init(entity="parabyl", project="ZA Twitter LSTM", name=f"5l, fc3, seq100, lr 1e-3 more clean")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    if LOAD_MODEL == True:
        checkpoint=torch.load(LOAD_CHECKPOINT, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        last_epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f"Model Loaded")
        print(last_epoch)
        print(last_loss)
        
    
    model.train()
        
    for epoch in range(num_epochs):
        state_h, state_c = model.init_state(sequence_length)
        state_h, state_c = state_h.to(device), state_c.to(device)
        
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred.transpose(1,2), y)
            
            state_h = state_h.detach()
            state_c = state_c.detach()
            
            loss.backward()
            optimizer.step()
            
            wandb.log({
                "Loss": loss.item(),
                "epoch": epoch,
            })
            
            if batch % 25 == 0:
                print(F"e: {epoch} | l: {loss.item()} | b: {batch}/{len(dataloader)}")
                text = " ".join(predict(dataset, model, "what does", device))
                print(text)
                wandb.log({
                    "Predictions": text,
                })
                model.train()
                
            
            if batch % 100 == 0: 
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'loss': loss, 
                }, "twitter_lstm_10plus.pth")
        
        
def predict(dataset, model, text, device, next_words=20):
    if LOAD_MODEL == True:
        print("Loading model for Prediction")
        checkpoint=torch.load(LOAD_CHECKPOINT, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
    
    words = text.split(" ")
    state_h, state_c = model.init_state(len(words))
    state_h, state_c = state_h.to(device), state_c.to(device)
    
    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))
        
        last_word_logits = y_pred[0][-1].detach()
        p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])
        
    return words

# prompts = [
#     "lasizwe is a",
#     "i am"
# ]

# for p in range(len(prompts)):
#     prompt = prompts[p]
#     autocomplete = predict(dataset, model, prompt, device=DEVICE, next_words=100-len(prompt))
#     print(" ".join(autocomplete))
#     print()
    
train(dataset, model, num_epochs=NUM_EPOCHS, sequence_length=SEQ_LEN, batch_size=BATCH_SIZE, device=DEVICE)

Vocab Size: 17741

Training


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Loss,███▇▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Loss,6.37298
Predictions,what does he'd revea...
epoch,0


e: 0 | l: 9.786608695983887 | b: 0/17313
what does mjita fat peers blue actual priorities perpetrators dudes 6:00 affirmative book accord sopie mum facibg grave dressing wonderkids "our mee
e: 0 | l: 6.590362071990967 | b: 25/17313
what does berry one puntshununu things and are as i blm sport carrying rt like innocent with out my ? in nasty
e: 0 | l: 6.366405010223389 | b: 50/17313
what does dodgy themselves: mosadi rt rt rt must , the a jack should with a call happening easy at rt cackling
e: 0 | l: 6.561718940734863 | b: 75/17313
what does 20x28inches showing letterhead performed go the , a here billion . ! rt sense have our a that 7 me


KeyboardInterrupt: 