In [1]:
# Dataload

import pandas as pd

df = pd.read_csv("./../_datasets/twitter_za/RSA_tweet_data.csv")
df = df["tweet_text"]
for i in range(len(df)):
    tweet = df.iloc[i][2:len(df.iloc[i])-1]    
    df.iloc[i] = tweet

print(f"Dataset Length: {len(df)}")
df.head()

  df = pd.read_csv("./../_datasets/twitter_za/RSA_tweet_data.csv")


Dataset Length: 67585


0    RT @TygressAndy: Her killer\xe2\x80\x99s famil...
1    my misandry doesn't go unjustified. \n#menaret...
2    RT @zozitunzi: My little sister's friend, a be...
3    RT @MatlhagaKebo: \xe2\x80\x9cWhy don\xe2\x80\...
4    RT @ElihleGwala: My heart bleeds for Kwasa\xe2...
Name: tweet_text, dtype: object

In [2]:
# Dataset

import torch
from collections import Counter

class Dataset(torch.utils.data.Dataset):
    def __init__(self, sequence_length=4):
        self.sequence_length = sequence_length
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()
        
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        
        self.words_indexes = [self.word_to_index[w] for w in self.words]
        
        print(f"Word List: {len(self.uniq_words)}")
        
    def load_words(self):
        text = df.str.cat(sep=" ")
        # Filters
        text = text.split(" ")
        text = [x.encode("ascii", "ignore") for x in text]
        text = [x.decode() for x in text if not "\\x" in x.decode()]
        text = [x for x in text if not "#" in x and not "@" in x and not "https://" in x and "\\n" not in x and not x == "RT"]
        text = [x.lower() for x in text]
        return text
    
    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)
    
    def __len__(self):
        return len(self.words_indexes) - self.sequence_length
    
    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )

In [3]:
# Model

from torch import nn

class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 256
        self.embedding_dim = 256
        self.num_layers = 10
        
        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)
        
    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state
    
    def init_state(self, sequence_length):
        return (
            torch.zeros(self.num_layers, sequence_length, self.lstm_size),
            torch.zeros(self.num_layers, sequence_length, self.lstm_size),
        )

In [None]:
# Train

import argparse
import numpy as np
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

def train(dataset, model, num_epochs=10, sequence_length=4, batch_size=256, device="cpu"):
    model.train()
    
    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
        
    for epoch in tqdm(range(num_epochs)):
        state_h, state_c = model.init_state(sequence_length)
        
        for batch, (x, y) in enumerate(dataloader):
            x, y = x.to(DEVICE), y.to(DEVICE)
            
            optimizer.zero_grad()
            
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred.transpose(1,2), y)
            
            state_h = state_h.detach()
            state_c = state_c.detach()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch % 10 == 0:
                print()
                print("=== === === === === === === === === === === === === === === ===")
                print()
                print(f"e: {epoch} - b: {batch} - l: {loss.item()}")
                print()
                print(" ".join(predict(dataset, model, "what does")))
                print()
                print("=== === === === === === === === === === === === === === === ===")
                print()
        
        
def predict(dataset, model, text, next_words=20):
    model.eval()
    
    words = text.split(" ")
    state_h, state_c = model.init_state(len(words))
    
    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))
        
        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])
    
    model.train()

    return words

# Main

# Next Steps:
# Clean up the data by removing non-letter characters.
# Increase the model capacity by adding more Linear or LSTM layers.
# Split the dataset into train, test, and validation sets.
# Add checkpoints so you don't have to train the model every time you want to run prediction.


NUM_EPOCHS = 10
SEQ_LEN = 20
BATCH_SIZE = 128
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = Dataset(sequence_length=SEQ_LEN)
model = Model(dataset)
model.to(DEVICE)

train(dataset, model, num_epochs=NUM_EPOCHS, sequence_length=SEQ_LEN, batch_size=BATCH_SIZE, device=DEVICE)

Word List: 23669


  0%|                                                                                                                                            | 0/10 [00:00<?, ?it/s]


=== === === === === === === === === === === === === === === ===

e: 0 - b: 0 - l: 10.07760238647461

what does dealers womxn times) really? seconds. pipe uyazi disagreement brass horizon? mistaken. poison; scientist farmers. lisbon, awesome! rap, swing clients, dembele

=== === === === === === === === === === === === === === === ===


=== === === === === === === === === === === === === === === ===

e: 0 - b: 10 - l: 8.141050338745117

what does legally creeps vga alongside 1 hle... heals, amberwood transfer withdraw ???? shame. im to rest grew 1800ksh nomination. investigated. shot,

=== === === === === === === === === === === === === === === ===


=== === === === === === === === === === === === === === === ===

e: 0 - b: 20 - l: 5.3764848709106445

what does choices bullrush. of yhoo 19 death stabbed to to demanded dumped. spot by by make persistence she giving losses a

=== === === === === === === === === === === === === === === ===


=== === === === === === === === === === === === 