<a href="https://colab.research.google.com/github/gvogiatzis/CS4740/blob/main/CS4740_Lab_Week_05.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import re
import pandas as pd
from textblob import Word
import numpy as np

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from more_itertools import sliced
%load_ext tensorboard

In [None]:
# Fetch the BBC news classification CSV from GitHub and save it as dataset.csv
# in the current working directory (-O overwrites any existing file of that name).
! wget https://github.com/suraj-deshmukh/BBC-Dataset-News-Classification/raw/master/dataset/dataset.csv -O dataset.csv

In [None]:
# Load the downloaded CSV (ISO-8859-1 encoded) and pull the 'news' column
# out as a plain Python list with one entry per row.
raw_data = pd.read_csv("dataset.csv", encoding="ISO-8859-1")
docs_txt = raw_data.loc[:, "news"].tolist()

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Concatenate every document into one string and keep only the first
# 1,000,000 characters as the training corpus.
# NOTE(review): the documents are joined with an empty separator, so the end of
# one document runs straight into the start of the next — confirm this is intended.
all_text = "".join(docs_txt)[:1000000]

In [None]:
# Display the number of characters in the corpus string.
len(all_text)

In [None]:
# Character vocabulary built from the corpus:
#   itoc: index -> character (distinct characters in sorted order)
#   ctoi: character -> index (inverse of itoc)
itoc = sorted(set(all_text))
ctoi = dict(zip(itoc, range(len(itoc))))
num_of_characters = len(itoc)

In [None]:
import random
def random_text(size, text_data):
    """Return a uniformly chosen contiguous slice of `text_data` of length `size`.

    Args:
        size: Number of characters to return; must satisfy
            0 <= size <= len(text_data).
        text_data: Sequence (typically a str) to sample from.

    Returns:
        `text_data[i:i+size]` for a random valid start offset `i`.

    Raises:
        ValueError: If `size` is negative or larger than `len(text_data)`.
    """
    last_start = len(text_data) - size
    if size < 0 or last_start < 0:
        raise ValueError(
            f"size must be between 0 and len(text_data)={len(text_data)}, got {size}"
        )
    # random.randint includes both endpoints, so the largest valid start offset
    # is len(text_data) - size; stopping one short of it would make the final
    # character of text_data unreachable.
    i = random.randint(0, last_start)
    return text_data[i:i+size]

In [None]:
class LSTMCharPred(nn.Module):
    """Character-level predictor: embedding -> single-layer GRU -> linear layer.

    The recurrent layer is stored under the attribute name ``lstm`` although it
    is an ``nn.GRU``; the attribute name is kept so that code accessing
    ``net.lstm`` keeps working.

    Args:
        charset_size: Number of distinct characters (embedding rows and output
            scores per position).
        embed_size: Dimensionality of the character embeddings.
        hidden_dim: Hidden-state size of the recurrent layer.
    """

    def __init__(self, charset_size, embed_size=100, hidden_dim=512*2):
        super().__init__()
        self.charset_size = charset_size
        self.hidden_dim = hidden_dim
        # Sub-modules are created in the order embedding, recurrent, linear.
        self.embedding = nn.Embedding(num_embeddings=charset_size,
                                      embedding_dim=embed_size)
        self.lstm = nn.GRU(
            input_size=embed_size,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=charset_size)

    def forward(self, x, batch_size=1):
        """Score every position of `x`.

        `x` is reshaped to (batch_size, -1) before embedding; the result is
        flattened to (-1, charset_size), i.e. one row of scores per input
        position.
        """
        token_rows = x.view(batch_size, -1)
        embedded = self.embedding(token_rows)
        hidden_states, _ = self.lstm(embedded)
        scores = self.fc(hidden_states)
        return scores.view(-1, self.charset_size)

In [None]:
class RunningAverage:
    """Accumulates values and reports their arithmetic mean.

    Use `add(x)` to record a value and call the instance to read the mean of
    everything recorded so far.
    """

    def __init__(self):
        self.n=0    # number of values recorded
        self.tot=0  # sum of the recorded values

    def add(self,x):
        """Record one more value `x`."""
        self.n += 1
        self.tot += x

    def __call__(self):
        """Return the mean of the recorded values, or NaN if none were recorded."""
        # Without this guard an empty average raises ZeroDivisionError, which
        # would abort a caller that merely wants to report a (missing) mean.
        if self.n == 0:
            return float("nan")
        return self.tot/self.n

In [None]:
tensorboard --logdir=runs

In [None]:
# Hyper-parameters.
num_of_epochs = 20
seq_length = 100
batch_size=64
embed_size=100
hidden_dim=512*2

net = LSTMCharPred(embed_size=embed_size, hidden_dim=hidden_dim, charset_size = num_of_characters).to(device)
loss = nn.CrossEntropyLoss()
optim = torch.optim.Adam(net.parameters(), lr=0.001)
net.train()

# One training batch is batch_size*seq_length input characters plus one extra
# character, so that the targets are the inputs shifted by one position.
chunk_len = batch_size*seq_length+1
# Iterations per epoch: the number of full chunks that fit in the corpus.
max_iter = len(all_text)//chunk_len
if max_iter == 0:
    raise ValueError(
        f"all_text has {len(all_text)} characters but one batch needs {chunk_len}"
    )

for e in range(num_of_epochs):
    train_acc = RunningAverage()
    for i in range(max_iter):
        # Every batch is a fresh random window of the corpus.
        txt = random_text(chunk_len,all_text)
        # Encode the window once; inputs and targets are overlapping slices of it.
        encoded = torch.tensor([ctoi[c] for c in txt], device = device)
        x = encoded[:-1]
        t = encoded[1:]
        optim.zero_grad()
        y = net(x,batch_size)
        L = loss(y, t)
        # Tensor-level .sum() counts the matches in one op; the builtin sum()
        # would iterate over the tensor element by element.
        acc = (y.argmax(dim=1)==t).sum().item()/(batch_size*seq_length)
        train_acc.add(acc)
        print(f"\rEpoch: {e}/{num_of_epochs} Iter: {i}/{max_iter}\tacc={100*acc:0.2f}%\tL={L.item():0.4f}", end="")
        L.backward()
        optim.step()
    print(f"\rEpoch: {e}/{num_of_epochs} Average acc: {train_acc()}")

In [None]:
def generate_text(net, seed_txt, length=100):
    """Greedily extend `seed_txt` by `length` characters using `net`.

    The seed is run through the network to obtain a hidden state; after that
    the highest-scoring (argmax) character is fed back in one step at a time.

    Args:
        net: Model exposing `embedding`, `lstm` and `fc` sub-modules.
        seed_txt: Non-empty string whose characters are all keys of `ctoi`.
        length: Number of characters to generate after the seed. Values <= 0
            return the seed unchanged.

    Returns:
        `seed_txt` followed by exactly `length` generated characters.

    Raises:
        ValueError: If `seed_txt` is empty.
        KeyError: If `seed_txt` contains a character missing from `ctoi`.
    """
    if not seed_txt:
        raise ValueError("seed_txt must contain at least one character")
    seed_lst_idx = [ctoi[c] for c in seed_txt]
    if length <= 0:
        return "".join(itoc[i] for i in seed_lst_idx)

    seed_idx = torch.tensor(seed_lst_idx, dtype=torch.long, device = device)
    generated_text = []
    # Remember the current mode so sampling does not leave the model in eval mode.
    was_training = net.training
    net.train(False)
    try:
        # No gradients are needed for sampling; skipping them avoids building
        # an autograd graph for every generated character.
        with torch.no_grad():
            out, h = net.lstm(net.embedding(seed_idx.view(1,-1)))
            # Only the last time step predicts the character that follows the seed.
            x = net.fc(out[:, -1:, :]).argmax(dim=2)
            generated_text.append(x.item())
            # The first character is already produced above, so length-1 remain.
            for _ in range(length-1):
                out, h = net.lstm(net.embedding(x.view(1,-1)), h)
                x = net.fc(out).argmax(dim=2)
                generated_text.append(x.item())
    finally:
        net.train(was_training)
    return "".join(itoc[i] for i in seed_lst_idx+generated_text)


In [None]:
# Print a sample continuation of the seed "Garments", requesting length=200.
print(generate_text(net, "Garments", length=200))

In [None]:
# Show the corpus text around the first occurrence of "Garments"
# (str.index raises ValueError if the word is absent).
idx=all_text.index("Garments")
# Clamp the start at 0: a negative start would be interpreted as an offset from
# the end of the string and return the wrong (usually empty) slice.
all_text[max(0, idx-100):idx+100]