In [1]:
# Dataload

import pandas as pd

df = pd.read_csv("./../_datasets/twitter_za/RSA_tweet_data.csv")
df = df["tweet_text"]
for i in range(len(df)):
    tweet = df.iloc[i][2:len(df.iloc[i])-1]    
    df.iloc[i] = tweet

print(f"Dataset Length: {len(df)}")
df.head()

  df = pd.read_csv("./../_datasets/twitter_za/RSA_tweet_data.csv")


Dataset Length: 67585


0    RT @TygressAndy: Her killer\xe2\x80\x99s famil...
1    my misandry doesn't go unjustified. \n#menaret...
2    RT @zozitunzi: My little sister's friend, a be...
3    RT @MatlhagaKebo: \xe2\x80\x9cWhy don\xe2\x80\...
4    RT @ElihleGwala: My heart bleeds for Kwasa\xe2...
Name: tweet_text, dtype: object

In [2]:
# Dataset

import torch
from collections import Counter

class Dataset(torch.utils.data.Dataset):
    def __init__(self, sequence_length=4):
        self.sequence_length = sequence_length
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()
        
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        
        self.words_indexes = [self.word_to_index[w] for w in self.words]
        
        print(f"Word List: {len(self.uniq_words)}")
        
    def load_words(self):
        text = df.str.cat(sep=" ")
        # Filters
        text = text.split(" ")
        text = [x.encode("ascii", "ignore") for x in text]
        text = [x.decode() for x in text if not "\\x" in x.decode()]
        text = [x for x in text if not "#" in x and not "@" in x and not "https://" in x and "\\n" not in x and not x == "RT"]
        text = [x.lower() for x in text]
        return text
    
    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)
    
    def __len__(self):
        return len(self.words_indexes) - self.sequence_length
    
    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )

In [3]:
# Model

from torch import nn

class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 5
        
        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2
        )
        self.fc = nn.Linear(self.lstm_size, self.lstm_size*2)
        self.fc2 = nn.Linear(self.lstm_size*2, self.lstm_size)
        self.fc3 = nn.Linear(self.lstm_size, n_vocab)
        
    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        output = self.fc(output)
        output = self.fc2(output)
        logits = self.fc3(output)
        return logits, state
    
    def init_state(self, sequence_length):
        return (
            torch.zeros(self.num_layers, sequence_length, self.lstm_size),
            torch.zeros(self.num_layers, sequence_length, self.lstm_size),
        )

In [4]:
# Train

import numpy as np
from torch import optim
from torch.utils.data import DataLoader
import wandb

def train(dataset, model, num_epochs=10, sequence_length=4, batch_size=256, device="cpu"):
    wandb.init(entity="parabyl", project="ZA Twitter LSTM", name=f"ten_layer")
    
    model.train()
    
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
        
    for epoch in range(num_epochs):
        state_h, state_c = model.init_state(sequence_length)
        state_h, state_c = state_h.to(device), state_c.to(device)
        
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred.transpose(1,2), y)
            
            state_h = state_h.detach()
            state_c = state_c.detach()
            
            loss.backward()
            optimizer.step()
            
            wandb.log({
                "Loss": loss.item(),
                "epoch": epoch,
            })
            
            if batch % 25 == 0:
                print(F"e: {epoch} | l: {loss.item()} | b: {batch}/{len(dataloader)}")
                text = " ".join(predict(dataset, model, "what does", device))
                print(text)
                wandb.log({
                    "Predictions": text,
                })
            
            if batch % 100 == 0: 
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'loss': loss, 
                }, "twitter_lstm.pth")
        
        
def predict(dataset, model, text, device, next_words=20):
    model.eval()
    
    words = text.split(" ")
    state_h, state_c = model.init_state(len(words))
    state_h, state_c = state_h.to(device), state_c.to(device)
    
    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))
        
        last_word_logits = y_pred[0][-1].detach()
        p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])
    
    model.train()

    return words

# Main

# Next Steps:
# Clean up the data by removing non-letter characters.
# Increase the model capacity by adding more Linear or LSTM layers.
# Split the dataset into train, test, and validation sets.
# Add checkpoints so you don't have to train the model every time you want to run prediction.

NUM_EPOCHS = 25
SEQ_LEN = 100
BATCH_SIZE = 64
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = Dataset(sequence_length=SEQ_LEN)
model = Model(dataset)
model.to(DEVICE)

train(dataset, model, num_epochs=NUM_EPOCHS, sequence_length=SEQ_LEN, batch_size=BATCH_SIZE, device=DEVICE)

e: 5 | l: 3.3535208702087402 | b: 6900/14634
what does not fit psg (france) woohoo! sooo exciting! we saying "i actually learn to they are johnnie walkers. we like and
e: 5 | l: 3.346433639526367 | b: 6925/14634
what does i had her problems are the vehicle the police use black president must be aware of we sab hs proven
e: 5 | l: 3.4160943031311035 | b: 6950/14634
what does not make us that this is here are must be on makes things, means that is being labeled. women, or
e: 5 | l: 3.503568649291992 | b: 6975/14634
what does so did the 112th premier class race! the country than never. yourself as it's normal sexualities into another happy friday
e: 5 | l: 3.3470425605773926 | b: 7000/14634
what does a the reason why is happening in do this matter. 3 cops hate years was stabbed a few in rising
e: 5 | l: 3.1848278045654297 | b: 7025/14634
what does twitterically we can u know well on his from luckyekeh  ....if it was a lil' breather from the mess
e: 5 | l: 3.281407356262207 | b: 7050/14634

KeyboardInterrupt: 