In [1]:
import torch
from torchsummary import summary
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.functional as F
from sentence_transformers import SentenceTransformer
from tqdm import tqdm



import pandas as pd
import numpy as np
import os

  from .autonotebook import tqdm as notebook_tqdm


# StarWars



In [2]:
# Directory containing the per-episode script files; the path is relative,
# so it resolves against the notebook's current working directory.
base_dir = "../data/raw/StarWarsEpisodes"

In [3]:
# Resolve the script file of each original-trilogy episode under `base_dir`.
episode_files = ("SW_EpisodeIV.txt", "SW_EpisodeV.txt", "SW_EpisodeVI.txt")
folder_ep4, folder_ep5, folder_ep6 = [
    os.path.join(base_dir, file_name) for file_name in episode_files
]

In [4]:
# The three scripts are parsed with identical settings: space-delimited
# fields, the first row as header, and backslash as the escape character.
csv_options = {"sep": ' ', "header": 0, "escapechar": '\\'}
df_ep4, df_ep5, df_ep6 = [
    pd.read_csv(script_path, **csv_options)
    for script_path in (folder_ep4, folder_ep5, folder_ep6)
]


In [5]:
# Display the Episode IV frame as a sanity check on the parse.
df_ep4


Unnamed: 0,character,dialogue
1,THREEPIO,Did you hear that? They've shut down the main...
2,THREEPIO,We're doomed!
3,THREEPIO,There'll be no escape for the Princess this time.
4,THREEPIO,What's that?
5,THREEPIO,I should have known better than to trust the l...
...,...,...
1006,LUKE,"Oh, no!"
1007,THREEPIO,"Oh, my! Artoo! Can you hear me? Say somethi..."
1008,TECHNICIAN,We'll get to work on him right away.
1009,THREEPIO,"You must repair him! Sir, if any of my circui..."


In [6]:
# Stack the three episodes (IV, V, VI order) into two parallel lists:
# Y holds the speaking character of each line, X the spoken dialogue.
episode_frames = [df_ep4, df_ep5, df_ep6]
Y = pd.concat([frame['character'] for frame in episode_frames]).tolist()
X = pd.concat([frame['dialogue'] for frame in episode_frames]).tolist()

In [7]:
# Build the classification label set: every character with at least 10 lines
# keeps its own class, all rarer characters are merged into a single "Other".
#
# `return_counts=True` yields the per-character line counts in one pass instead
# of re-scanning Y once per character with Python's builtin `sum`.
character_names, line_counts = np.unique(Y, return_counts=True)

# Kept for parity with the previous cell: number of lines per unique character,
# aligned with the sorted unique character names.
label_count = line_counts.tolist()

# `np.where` promotes the string dtype as needed, so "Other" can never be
# truncated the way an in-place assignment into a narrow '<U..' array could be.
labels = np.unique(np.where(line_counts < 10, "Other", character_names))


In [8]:
# Show the final label set; characters with fewer than 10 lines appear as "Other".
labels

array(['ACKBAR', 'BEN', 'BIGGS', 'COMMANDER', 'CREATURE', 'EMPEROR',
       'GOLD LEADER', 'HAN', 'JABBA', 'LANDO', 'LEIA', 'LUKE', 'OFFICER',
       'OWEN', 'Other', 'PIETT', 'RED LEADER', 'RIEEKAN', 'TARKIN',
       'THREEPIO', 'TROOPER', 'VADER', 'WEDGE', 'YODA'], dtype='<U30')

In [9]:
# Bidirectional lookup between a label name and its integer class index
# (the index is the label's position in `labels`).
char2ind = {name: position for position, name in enumerate(labels)}
ind2char = {position: name for name, position in char2ind.items()}

In [10]:
# char_names = movie_lines.iloc[:,0]
# movie_names = movie_lines.iloc[:,1]
# char_names = np.unique(list(set(char_names.values)))
# movie_names = np.unique(list(set(movie_names.values)))

In [19]:
class CustomStarWarsDataset(Dataset):
    """Map-style dataset yielding ``(sentence_embedding, class_index)`` pairs.

    Each dialogue line is embedded with a frozen sentence encoder and paired
    with the integer class of its speaker. Speakers missing from the label map
    are assigned to the ``"Other"`` class.

    Args:
        X: Sequence of dialogue strings.
        Y: Sequence of speaker names, parallel to ``X``.
        transform: Optional callable applied to the embedding before it is
            returned.
        target_transform: Optional callable applied to the class index before
            it is returned.
        label_map: Optional mapping from speaker name to class index; it must
            contain an ``"Other"`` entry for unknown speakers. Defaults to the
            notebook-level ``char2ind``.
        encoder: Optional object exposing ``encode(text)`` and ``parameters()``.
            Defaults to ``SentenceTransformer('bert-base-nli-mean-tokens')``.

    Raises:
        ValueError: If ``X`` and ``Y`` differ in length.
    """

    def __init__(self, X, Y, transform=None, target_transform=None,
                 label_map=None, encoder=None):
        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same length, got {len(X)} and {len(Y)}"
            )
        self.X = X
        self.Y = Y
        self.label_map = char2ind if label_map is None else label_map
        if encoder is None:
            encoder = SentenceTransformer('bert-base-nli-mean-tokens')
        self.sentence_model = encoder
        # The encoder is only a fixed feature extractor; it is never fine-tuned.
        for param in self.sentence_model.parameters():
            param.requires_grad = False
        self.transform = transform
        self.target_transform = target_transform
        # idx -> raw embedding. The encoder is frozen, so a line's embedding is
        # computed once and reused on every later epoch.
        self._embedding_cache = {}

    def __len__(self):
        """Return the number of (dialogue, speaker) samples."""
        return len(self.Y)

    def __getitem__(self, idx):
        """Return ``(embedding, class_index)`` for sample ``idx``."""
        speaker = self.Y[idx]
        if speaker not in self.label_map:
            speaker = "Other"
        label_point = self.label_map[speaker]

        sentence_encoded = self._embedding_cache.get(idx)
        if sentence_encoded is None:
            sentence_encoded = self.sentence_model.encode(self.X[idx])
            self._embedding_cache[idx] = sentence_encoded

        # Transforms run on every access and should not mutate the cached
        # embedding in place.
        if self.transform is not None:
            sentence_encoded = self.transform(sentence_encoded)
        if self.target_transform is not None:
            label_point = self.target_transform(label_point)
        return sentence_encoded, label_point

In [20]:
# Number of speaker labels collected across the three episodes.
len(Y)

2523

In [31]:
class BertSentenceClassifier(nn.Module):
    """Two-layer MLP head that classifies a sentence embedding by speaker.

    Args:
        input_dim: Size of the incoming sentence embedding (default 768).
        hidden_dim: Width of the hidden layer (default 64).
        num_classes: Number of output classes. Defaults to ``len(labels)``,
            the notebook-level label set.
    """

    def __init__(self, input_dim=768, hidden_dim=64, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = len(labels)

        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, data):
        """Return raw class logits for a batch of embeddings.

        No softmax is applied here: ``nn.CrossEntropyLoss`` applies
        log-softmax itself, so feeding it probabilities would apply softmax
        twice and flatten the gradients. Apply
        ``torch.softmax(logits, dim=1)`` at inference time if probabilities
        are needed.

        Args:
            data: Float tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor of unnormalised scores with last dimension ``num_classes``.
        """
        hidden = nn.functional.relu(self.lin1(data))
        return self.lin2(hidden)

In [37]:
# Wrap the full corpus in the dataset, then batch it for training:
# shuffled mini-batches of 4, discarding the final incomplete batch.
train_dataset = CustomStarWarsDataset(X, Y)
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
)

In [42]:
# Fresh classifier head plus an Adam optimizer over all of its parameters.
model = BertSentenceClassifier()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=3e-4,
)

In [43]:
# Print the per-layer parameter counts of the classifier.
summary(model)

Layer (type:depth-idx)                   Param #
├─Linear: 1-1                            49,216
├─Linear: 1-2                            1,560
Total params: 50,776
Trainable params: 50,776
Non-trainable params: 0


Layer (type:depth-idx)                   Param #
├─Linear: 1-1                            49,216
├─Linear: 1-2                            1,560
Total params: 50,776
Trainable params: 50,776
Non-trainable params: 0

In [44]:
# Pick the best available backend: CUDA GPU first, then Apple-silicon MPS,
# then CPU. `torch.backends.mps` is looked up defensively because older torch
# builds do not ship that submodule.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"


In [45]:

num_epochs = 3
print(device)
model = model.to(device)
model.train()  # make sure training-mode behaviour is active

# Build the loss module once instead of on every batch. CrossEntropyLoss
# applies log-softmax internally, so it expects raw logits from the model.
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    loop = tqdm(train_loader, total=len(train_loader))
    loop.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")

    running_loss = 0.0
    num_batches = 0

    for text, author_labels in loop:
        embeddings = text.to(device)
        author_labels = author_labels.to(device)

        optimizer.zero_grad()
        predictions = model(embeddings)
        loss = criterion(predictions, author_labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        loop.set_postfix(loss=batch_loss)

    # Report the mean loss over the epoch; a single 4-sample batch is too noisy
    # to track progress. The guard covers an empty loader.
    epoch_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Epoch: {epoch + 1}, Loss: {epoch_loss}")


mps


Epoch [1/3]:  81%|████████  | 511/630 [01:38<00:23,  5.16it/s, loss=2.96]


KeyboardInterrupt: 