<a href="https://colab.research.google.com/github/middlebury-csci-0451/CSCI-0451/blob/main/lecture-notes/text-classification.ipynb" target="_parent">Open these notes in Google Colab</a>

<a href="https://colab.research.google.com/github/middlebury-csci-0451/CSCI-0451/blob/main/lecture-notes/text-classification-live.ipynb" target="_parent">Open the live version in Google Colab</a>


*Major components of this set of lecture notes are based on the [Text Classification](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html) tutorial from the PyTorch documentation*. 

## Deep Text Classification and Word Embedding

In this set of notes, we'll discuss the problem of *text classification*. Text classification is a common problem in which we aim to classify pieces of text into different categories. These categories might be about:

- **Subject matter**: is this news article about politics, fashion, or finance?
- **Emotional valence**: is this tweet happy or sad? Excited or calm? This particular class of questions is so important that it has its own name: sentiment analysis.
- **Automated content moderation**: is this Facebook comment a possible instance of abuse or harassment? Is this Reddit thread promoting violence? Is this email spam?

We saw text classification previously when we first considered the problem of vectorizing pieces of text. We are now going to look at a somewhat more contemporary approach to text using *word embeddings*. 
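A word embedding assigns each vocabulary index a learnable dense vector. In PyTorch this is `nn.Embedding`, which is simply a lookup into the rows of a weight matrix. The toy sizes below (a 10-word vocabulary, 3-dimensional vectors) are arbitrary choices for illustration. The sketch checks that the lookup agrees with multiplying one-hot encodings by the weight matrix, which is the sense in which an embedding replaces the sparse one-hot vectors of earlier vectorization approaches with short dense ones.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

VOCAB_SIZE, EMB_DIM = 10, 3          # toy sizes, chosen only for illustration
emb = nn.Embedding(VOCAB_SIZE, EMB_DIM)

# a "sentence" of three token indices, with a batch dimension of 1
ids = torch.tensor([[1, 2, 4]])
vectors = emb(ids)                    # shape (1, 3, EMB_DIM)

# the lookup returns row i of the weight matrix for token index i ...
same_row = torch.equal(vectors[0, 0], emb.weight[1])

# ... which matches multiplying a one-hot encoding by the weight matrix
one_hot = F.one_hot(ids, num_classes=VOCAB_SIZE).float()   # shape (1, 3, VOCAB_SIZE)
via_matmul = one_hot @ emb.weight                           # shape (1, 3, EMB_DIM)

print(vectors.shape, same_row, torch.allclose(vectors, via_matmul))
```

Because the embedding weights are trained along with the rest of the network, words that play similar roles in the training text can end up with nearby vectors.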


In [705]:
import pandas as pd
import torch
import numpy as np
from torchsummary import summary

# for embedding visualization later
import plotly.express as px 
import plotly.io as pio

# for VSCode plotly rendering
# pio.renderers.default = "plotly_mimetype+notebook_connected"
pio.renderers.default = "plotly_mimetype+notebook"

pio.templates.default = "plotly_white"

from sklearn.model_selection import train_test_split

import spacy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Our Task

Today, we are going to see whether we can teach an algorithm to understand and reproduce the pinnacle of cultural achievement; the benchmark against which all art is to be judged; the mirror that reveals to humanity its truest self. I speak, of course, of *Star Trek: Deep Space Nine.*

<figure class="image" style="width:300px">
  <img src="https://raw.githubusercontent.com/PhilChodrow/PIC16B/master/_images/DS9.jpg" alt="">
  <figcaption><i></i></figcaption>
</figure>

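In particular, we are going to attempt to teach a neural network to generate *episode scripts*. This is a text generation task: after training, our hope is that our model will be able to create scripts that are reasonably realistic in their appearance. We will frame it as a classification problem, though: given the previous few words, predict which word in the vocabulary comes next.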


In [780]:
## miscellaneous data cleaning

start_episode = 20 # start in Season 2; Season 1 is not very good
num_episodes = 5  # only pick this many episodes to train on

url = "https://github.com/PhilChodrow/PIC16B/blob/master/datasets/star_trek_scripts.json?raw=true"
star_trek_scripts = pd.read_json(url)

cleaned = star_trek_scripts["DS9"].str.replace("\n\n\n\n\n\nThe Deep Space Nine Transcripts -", "")
cleaned = cleaned.str.split("\n\n\n\n\n\n\n").str.get(-2)
text = "\n\n".join(cleaned[start_episode:(start_episode + num_episodes)])
for char in ['\xa0', 'à', 'é', "}", "{"]:
    text = text.replace(char, "")

In [781]:
print(text[0:500])

  Last
time on Deep Space Nine.  
SISKO: This is the emblem of the Alliance for Global Unity. They call
themselves the Circle. 
O'BRIEN: What gives them the right to mess up our station? 
ODO: They're an extremist faction who believe in Bajor for the
Bajorans. 
SISKO: I can't loan you a Starfleet runabout without knowing where you
plan on taking it. 
KIRA: To Cardassia Four to rescue a Bajoran prisoner of war. 
(The prisoners are rescued.) 
KIRA: Come on. We have a ship waiting. 
JARO: What you 


In [782]:
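def tokenizer(text):
    # split into lines, then split each line on whitespace;
    # a "\n" token is inserted between lines so line breaks (e.g. new speakers) are kept
    L = [s.split() for s in text.split("\n")]
    out = L[0]
    for i in range(1, len(L)):
        out += ["\n"]
        out += L[i] 
    return out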
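The windowing step stores each window as a space-joined string and later re-tokenizes it, so it relies on `tokenizer` recovering exactly the same tokens, newline markers included. Here is a quick self-contained check of that round trip, with the tokenizer restated in compact form (same behavior) so the snippet runs on its own; the sample string is made up for illustration:

```python
def tokenizer(text):
    # split into lines, split each line on whitespace, and put a "\n" token between lines
    lines = [line.split() for line in text.split("\n")]
    out = list(lines[0])
    for line in lines[1:]:
        out += ["\n"] + line
    return out

sample = "Last\ntime on Deep Space Nine.\n\nSISKO: This is the"
tokens = tokenizer(sample)

# join with spaces (as the windowing code does), then tokenize again
round_trip = tokenizer(" ".join(tokens))
print(tokens)
print(round_trip == tokens)
```

The round trip works because `str.split()` with no argument discards the spaces that joining adds around each `"\n"` token, and an empty line becomes an empty list, so consecutive newline tokens survive.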

In [783]:
len(tokenizer("Last \n time on Deep Space Nine. \n SISKO: This"))

10

In [784]:
WINDOW = 10 # predict next word from 10 previous words
word_seq = tokenizer(text)

predictors = []
targets    = []

# slide a window over the token sequence: WINDOW tokens in, the following token out
for i in range(len(word_seq) - WINDOW - 1):
    predictors.append(" ".join(word_seq[i:(i+WINDOW)]))
    targets.append(word_seq[WINDOW+i])
    
i = 0
len(tokenizer(predictors[i])), targets[i]

(10, 'is')
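The loop above is a sliding window: each training example pairs `WINDOW` consecutive tokens with the token that follows them. A toy version with a 3-token window makes the pairing easy to see; the helper name `sliding_windows`, the token list, and the window size are made up for illustration. Note that `range(len(tokens) - window)` yields every complete window, while the loop above stops one step earlier because of its extra `- 1`, which drops only the final example.

```python
def sliding_windows(tokens, window):
    """Pair each run of `window` tokens with the token that follows it."""
    predictors, targets = [], []
    for i in range(len(tokens) - window):
        predictors.append(" ".join(tokens[i:i + window]))
        targets.append(tokens[i + window])
    return predictors, targets

toy_tokens = ["SISKO:", "This", "is", "the", "emblem", "of", "the", "Alliance"]
X_toy, y_toy = sliding_windows(toy_tokens, window=3)

for x, y in zip(X_toy, y_toy):
    print(f"{x!r:>25} -> {y!r}")
```

An 8-token sequence with a 3-token window gives 5 examples, the last of which predicts `"Alliance"`.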

In [785]:
from torch.utils.data import Dataset, DataLoader

class TextDataSet(Dataset):
    def __init__(self, predictors, targets):
        self.predictors = predictors
        self.targets = targets
    
    def __getitem__(self, index):
        return self.predictors[index], self.targets[index]

    def __len__(self):
        return len(self.targets)

In [786]:
data = TextDataSet(predictors, targets)

In [787]:
from torchtext.vocab import build_vocab_from_iterator
def yield_tokens(data_iter):
    for text, word in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(data), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
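`build_vocab_from_iterator` counts token frequencies and assigns each token an integer index, with the specials first and, as the `all_words` listing later in these notes suggests, more frequent tokens receiving smaller indices. Below is a simplified pure-Python stand-in to show the idea. The helper names `build_toy_vocab` and `lookup` are hypothetical, and the alphabetical tie-breaking rule is an assumption made for this sketch, not torchtext's code.

```python
from collections import Counter

def build_toy_vocab(token_lists, specials=("<unk>",)):
    """Map tokens to indices: specials first, then by descending frequency.

    Ties are broken alphabetically here (an assumption made for this sketch).
    """
    counts = Counter(tok for tokens in token_lists for tok in tokens)
    ordered = sorted(counts, key=lambda tok: (-counts[tok], tok))
    itos = list(specials) + [tok for tok in ordered if tok not in specials]
    stoi = {tok: i for i, tok in enumerate(itos)}
    return stoi, itos

def lookup(stoi, tokens, default="<unk>"):
    # unknown tokens fall back to the default index, like vocab.set_default_index
    return [stoi.get(tok, stoi[default]) for tok in tokens]

stoi, itos = build_toy_vocab([["the", "station", "the"], ["KIRA:", "the", "station"]])
print(itos)
print(lookup(stoi, ["the", "wormhole"]))
```

Setting a default index matters at generation time: any word the model has never seen maps to `<unk>` instead of raising an error.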

In [788]:
len(tokenizer(next(iter(data))[0]))

10

In [789]:
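# define data loader

def collate_batch(batch):
    text_list, next_word_list = [], []
    for (text, next_word) in batch:
        # convert each window of words to a list of WINDOW integer indices
        processed_text = vocab(tokenizer(text))
        text_list.append(processed_text)
        # vocab() takes a list, so this appends a one-element list per target
        next_word_list.append(vocab([next_word]))
    # (batch_size, 1) -> (batch_size,); squeeze(1) keeps the batch dim even when batch_size == 1
    next_word_list = torch.tensor(next_word_list, dtype=torch.int64).squeeze(1)
    text_list = torch.tensor(text_list, dtype=torch.int64)   # (batch_size, WINDOW)
    return text_list.to(device), next_word_list.to(device)

data_loader = DataLoader(data, batch_size=8, shuffle=False, collate_fn=collate_batch)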
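In `collate_batch`, each target is looked up as `vocab([next_word])`, a one-element list, so the stacked targets form a `(batch_size, 1)` tensor. Calling `.squeeze()` with no argument removes *every* size-1 dimension, which turns a final batch of size 1 into a 0-dimensional tensor; squeezing only dimension 1 always keeps the batch dimension. This only bites when the number of examples leaves a remainder of one after dividing by the batch size. A small check with made-up indices:

```python
import torch

full_batch = torch.tensor([[4], [17], [2]], dtype=torch.int64)   # shape (3, 1)
last_batch = torch.tensor([[9]], dtype=torch.int64)              # shape (1, 1)

print(full_batch.squeeze().shape, last_batch.squeeze().shape)     # all size-1 dims removed
print(full_batch.squeeze(1).shape, last_batch.squeeze(1).shape)   # only dim 1 removed
```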


In [837]:
text_list, next_word_list = next(iter(data_loader))

In [838]:
len(data_loader)

3730

In [866]:
from torch import nn

class TextGenModel(nn.Module):
    # word embedding -> LSTM -> linear layer producing one score (logit) per vocabulary word
    def __init__(self, vocab_size, embedding_dim, window):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim*window, hidden_size = 100, num_layers = 1, batch_first = True)
        self.fc   = nn.Linear(100, vocab_size)
        
    def forward(self, x):
        x = self.embedding(x)       # (batch, window, embedding_dim)
        x = torch.flatten(x, 1)     # (batch, window * embedding_dim)
        # note: nn.LSTM treats a 2-D input as one unbatched sequence,
        # so the batch dimension acts as the sequence dimension here
        x, (hn, cn) = self.lstm(x)  # (batch, 100)
        x = self.fc(x)              # (batch, vocab_size)
        return x
        
EMBEDDING_DIM = 5
# move the model to the same device the data loader puts the batches on
TGM = TextGenModel(len(vocab), EMBEDDING_DIM, WINDOW).to(device)
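In `TextGenModel`, the embedded window is flattened into one vector per example, so the LSTM receives a 2-D `(batch, window * embedding_dim)` tensor. PyTorch interprets a 2-D input to `nn.LSTM` as a single *unbatched* sequence, so the batch dimension is treated as the time dimension and hidden state flows from one example to the next within a batch. A variant that lets the LSTM read each window one word at a time is sketched below. The class name `SeqTextGenModel`, the hidden size of 100, and the choice to predict from the last time step are assumptions for illustration, not the model trained in these notes.

```python
import torch
from torch import nn

class SeqTextGenModel(nn.Module):
    """Hypothetical variant: the LSTM reads the window as a sequence of word vectors."""

    def __init__(self, vocab_size, embedding_dim, hidden_size=100):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # input_size is the size of ONE word vector, not of the whole flattened window
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        # x: (batch, window) integer token indices
        x = self.embedding(x)          # (batch, window, embedding_dim)
        out, (hn, cn) = self.lstm(x)   # out: (batch, window, hidden_size)
        last = out[:, -1, :]           # hidden state after reading the final word
        return self.fc(last)           # (batch, vocab_size) logits for the next word

# shape check with toy sizes
torch.manual_seed(0)
model = SeqTextGenModel(vocab_size=50, embedding_dim=5)
windows_toy = torch.randint(0, 50, (8, 10))   # 8 windows of 10 token indices
logits = model(windows_toy)
print(logits.shape)
```

Because the input is 3-D, each window is processed independently, so a window's logits are the same whether it is scored alone or in a batch. The output has the same `(batch, vocab_size)` shape as `TGM`'s, so it could in principle replace `TGM` in the training loop; this sketch has not been trained, so no claims are made about its loss or samples.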

In [867]:
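# cross-entropy loss over the vocabulary: the next word is the class label
loss_fn = torch.nn.CrossEntropyLoss()
X, y = next(iter(data_loader))
loss_fn(TGM(X), y)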

tensor(8.5928, grad_fn=<NllLossBackward0>)
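A useful sanity check for this first loss value: a freshly initialized network produces small, roughly similar logits, so its predicted distribution is close to uniform over the vocabulary, and the cross-entropy of a uniform prediction over $V$ classes is exactly $\ln V$ whatever the true class is. An initial loss of 8.59 therefore corresponds to roughly $e^{8.59} \approx 5{,}400$ equally likely words, which is worth comparing against `len(vocab)`. The vocabulary size in the sketch below is an arbitrary example, and `cross_entropy` is a hypothetical helper written out by hand.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of softmax(logits) against the integer class `target`."""
    m = max(logits)                                  # subtract max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

V = 1000                      # example vocabulary size (arbitrary)
uniform_logits = [0.0] * V    # equal logits -> uniform predicted distribution

print(cross_entropy(uniform_logits, target=3), math.log(V))
```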

In [868]:
import time
def train(dataloader):
    # keep track of some counts for measuring accuracy
    total_acc, total_count, total_loss = 0, 0, 0
    log_interval = 1000
    start_time = time.time()

    for idx, (text_seq, next_word) in enumerate(dataloader):

        # zero gradients
        optimizer.zero_grad()
        # form prediction on batch
        preds = TGM(text_seq)
        # evaluate loss on prediction
        loss = loss_fn(preds, next_word)
        # compute gradient
        loss.backward()
        # take an optimization step
        optimizer.step()

        # for printing accuracy and loss
        total_acc   += (preds.argmax(1) == next_word).sum().item()
        total_count += next_word.size(0)
        # loss is a batch mean, so total_loss / total_count is (mean loss) / (batch size)
        total_loss  += loss.item() 
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| train loss {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_loss/total_count))
            total_acc, total_loss, total_count = 0, 0, 0
            start_time = time.time()
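`CrossEntropyLoss` returns the *mean* loss over the batch by default, so `total_loss` in `train` accumulates batch means while `total_count` counts individual examples. The printed "train loss" is therefore the average per-example loss divided by the batch size (8 here), so values near 0.6 correspond to a per-example cross-entropy of roughly 5. One way to track a true per-example average is to weight each batch mean by its batch size, as in this small sketch; the helper name `running_mean_loss` and the numbers are made up for illustration.

```python
def running_mean_loss(batch_means, batch_sizes):
    """Average per-example loss from per-batch mean losses and batch sizes."""
    total_loss = sum(m * n for m, n in zip(batch_means, batch_sizes))
    total_count = sum(batch_sizes)
    return total_loss / total_count

batch_means = [5.0, 4.0, 6.0]   # made-up per-batch mean losses
batch_sizes = [8, 8, 4]         # the last batch is smaller

print(running_mean_loss(batch_means, batch_sizes))   # weighted per-example mean
print(sum(batch_means) / sum(batch_sizes))            # what the logging in train computes
```

Inside `train`, the same idea amounts to accumulating `loss.item() * next_word.size(0)` instead of `loss.item()`.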

In [880]:
optimizer = torch.optim.Adam(TGM.parameters(), lr=0.0001)
loss_fn = torch.nn.CrossEntropyLoss()

EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(data_loader)
    
    print('| end of epoch {:3d} | time: {:5.2f}s | '.format(epoch,
                                           time.time() - epoch_start_time))
    print('-' * 65)

| epoch   1 |  1000/ 3730 batches | train loss    0.610
| epoch   1 |  2000/ 3730 batches | train loss    0.663
| epoch   1 |  3000/ 3730 batches | train loss    0.684
| end of epoch   1 | time: 46.31s | 
-----------------------------------------------------------------
| epoch   2 |  1000/ 3730 batches | train loss    0.594
| epoch   2 |  2000/ 3730 batches | train loss    0.647
| epoch   2 |  3000/ 3730 batches | train loss    0.669
| end of epoch   2 | time: 48.70s | 
-----------------------------------------------------------------
| epoch   3 |  1000/ 3730 batches | train loss    0.588
| epoch   3 |  2000/ 3730 batches | train loss    0.637
| epoch   3 |  3000/ 3730 batches | train loss    0.658
| end of epoch   3 | time: 45.57s | 
-----------------------------------------------------------------
| epoch   4 |  1000/ 3730 batches | train loss    0.581
| epoch   4 |  2000/ 3730 batches | train loss    0.627
| epoch   4 |  3000/ 3730 batches | train loss    0.647
| end of epoch   4 

In [881]:
all_words = vocab.get_itos()   # index -> word lookup list

def sample_from_preds(preds, temp = 1):
    # scale the logits by 1/temp, convert to probabilities, and draw one index
    probs = torch.softmax(preds / temp, dim=0)
    return torch.multinomial(probs, 1).item()

def sample_next_word(text, temp = 1, window = 10):
    # the model expects exactly `window` token indices, so keep only the last `window` tokens
    token_ix = vocab(tokenizer(text)[-window:])
    X = torch.tensor([token_ix], dtype = torch.int64).to(device)
    with torch.no_grad():   # no gradients are needed when sampling
        preds = TGM(X).flatten()
    new_ix = sample_from_preds(preds, temp)
    return all_words[new_ix]
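`sample_from_preds` divides the logits by a temperature before the softmax. Temperatures below 1 sharpen the distribution toward the most likely word and temperatures above 1 flatten it, which is one reason the sample below, drawn at `temp = .1`, keeps repeating the most common lines. The helper `softmax_with_temperature` and the logits here are made up for illustration:

```python
import math

def softmax_with_temperature(logits, temp=1.0):
    """Softmax of logits / temp; smaller temp -> sharper distribution."""
    scaled = [z / temp for z in logits]
    m = max(scaled)                           # subtract max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5, 0.0]   # made-up scores for four candidate words

for temp in (0.1, 1.0, 5.0):
    probs = softmax_with_temperature(logits, temp)
    print(temp, [round(p, 3) for p in probs])
```

At `temp = 0.1` nearly all of the probability lands on the top-scoring word, so sampling behaves almost like always taking the argmax; raising the temperature trades that repetitiveness for more varied, and less coherent, output.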

In [888]:
seed = 'Last time on Deep Space Nine. \n SISKO: This is the'

def sample_from_model(seed, n_words, temp, window):
    text = seed 
    for i in range(n_words):
        word = sample_next_word(text, temp, window)
        text += " " + word
    return seed, text    


synth = sample_from_model(seed, 500, .1, 10)


print(synth[0])
print("-"*50)
print(synth[1])

Last time on Deep Space Nine. 
 SISKO: This is the
--------------------------------------------------
Last time on Deep Space Nine. 
 SISKO: This is the emblem 
 BASHIR: 
 BASHIR: I don't 
 BASHIR: 
 BASHIR: I'm, 
 MELORA: I 
 BASHIR: I don't the think 
 BASHIR: [on 
 SISKO: I don't you 
 BASHIR: I don't 
 BASHIR: I 
 O'BRIEN: I don't 
 BASHIR: I don't to think 
 BASHIR: 
 BASHIR: I don't you 
 BASHIR: I don't 
 BASHIR: I 
 MELORA: I don't you 
 BASHIR: I don't 
 BASHIR: I 
 MELORA: I prophecies 
 BASHIR: I don't to think 
 BASHIR: I 
 MELORA: I 
 BASHIR: I don't the think 
 BASHIR: [on 
 BASHIR: I don't you 
 BASHIR: I don't you think 
 BASHIR: [on 
 BASHIR: I 
 BASHIR: I don't 
 BASHIR: I don't to your 
 BASHIR: 
 BASHIR: I don't 
 BASHIR: 
 BASHIR: I'm, 
 O'BRIEN: I 
 BASHIR: I don't the think 
 
 BASHIR: 
 
 BASHIR: I don't 
 BASHIR: I don't of 
 MELORA: I 
 BASHIR: I don't 
 BASHIR: Well, 
 MELORA: I don't you 
 BASHIR: I don't you job, 
 BASHIR: 
 BASHIR: I don't 
 BASHIR: 
 BASH

In [727]:
all_words

['<unk>',
 '\n',
 'the',
 'to',
 'I',
 'you',
 'a',
 'of',
 'and',
 'SISKO:',
 'is',
 'in',
 'that',
 'have',
 'be',
 'for',
 'KIRA:',
 'it',
 'BASHIR:',
 'QUARK:',
 'on',
 "I'm",
 'ODO:',
 'your',
 'DAX:',
 "O'BRIEN:",
 'You',
 'was',
 'with',
 'we',
 'my',
 'this',
 'The',
 'are',
 'not',
 "don't",
 'me',
 'do',
 'know',
 'can',
 'about',
 'what',
 'all',
 'but',
 'just',
 'as',
 'at',
 'get',
 'going',
 "It's",
 'you.',
 'like',
 'he',
 'been',
 'if',
 'one',
 'And',
 'his',
 'want',
 'from',
 'What',
 'it.',
 'think',
 'out',
 'an',
 'But',
 'will',
 'GARAK:',
 'were',
 'they',
 'We',
 'me.',
 'would',
 'It',
 "you're",
 "I'll",
 'our',
 "I've",
 'If',
 'no',
 'see',
 'up',
 "That's",
 'him',
 'has',
 'some',
 'could',
 'JAKE:',
 'so',
 'by',
 "it's",
 'never',
 'had',
 'them',
 'He',
 'any',
 'into',
 'us',
 'there',
 'when',
 'her',
 'back',
 'here',
 'Well,',
 'who',
 'time',
 'here.',
 "can't",
 'you,',
 'how',
 "You're",
 'more',
 'take',
 'A',
 'ROM:',
 'or',
 'got',
 'their'