<a href="https://colab.research.google.com/github/esfandiaryfard/machine-learning/blob/main/freinds.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn; cudnn.benchmark = True
import urllib

urllib.request.urlretrieve(
    'https://drive.google.com/uc?export=download&id=1f10Zoa_Lqg82-BnFLCHavVWig3Ei2fMV', 
    'script.txt'
    )

('script.txt', <http.client.HTTPMessage at 0x7fbb307f9190>)

In [2]:
data_path = "script.txt"
batch_size = 8
batch_seq_len = 64
embed_size = 512
rnn_size = 1024

with open(data_path) as f:
    text = f.read()
text = text[180:]

token_dict = {".": "|fullstop|",
              ",": "|comma|",
              "\"": "|quote|",
              ";": "|semicolon|",
              "!": "|exclamation|",
              "?": "|question|",
              "(": "|leftparen|",
              ")": "|rightparen|",
              "--": "|dash|",
              "\n": "|newline|"
}
for punct, token in token_dict.items():
    text = text.replace(punct, f' {token} ')

text[:200]

"Monica: There's nothing to tell |exclamation|  He's just some guy I work with |exclamation|  |newline| Joey: C'mon |comma|  you're going out with the guy |exclamation|  There's gotta be something wron"

In [3]:
words = text.split(" ")
words = [word for word in words if len(word) > 0]
vocab = list(set(words))

for i, w in enumerate(vocab[:5]):
  print(i, w)

0 cheap
1 final
2 rarely
3 quantity
4 Brosnan


In [4]:
vocab_to_int = {word: i for i,word in enumerate(vocab)}
int_to_vocab = {i: word for i,word in enumerate(vocab)}

num_words = len(vocab)
print(num_words)

27750


In [5]:
print(len([word for word in text.split(" ") if len(word) > 0]))

1207503


In [6]:
text_ints = [vocab_to_int[word] for word in text.split(" ") if len(word) > 0]
len(text_ints)

1207503

In [7]:
import re
scene = re.findall(r'\[Scene.*?\]', text)
num_scenes = len(scene)
print(len(text_ints)/num_scenes)

389.8944139489829


In [8]:
new_text = [word for word in text.split(" ") if len(word) > 0]
inputs = new_text[:5]
target = new_text[1:5]

print(inputs)
print(target)

['Monica:', "There's", 'nothing', 'to', 'tell']
["There's", 'nothing', 'to', 'tell']


In [9]:
scene_length = 400

In [10]:
def get_batches(text_ints, scene_length, batch_size, batch_seq_len):
    num_scenes = len(text_ints)//scene_length
    text_targets = text_ints[1:] + [text_ints[0]]
    scene_inputs = [
                    text_ints[i * scene_length : (i+1) * scene_length] for i in 
                    range(num_scenes)
                    ]
    scene_targets = [
                     text_targets[i*scene_length:(i+1)*scene_length] for i in 
                     range(num_scenes)
                     ]
    num_mini_sequences = scene_length//batch_seq_len
    scene_inputs = [
                    [scene[i*batch_seq_len:(i+1)*batch_seq_len] for i in 
                     range(num_mini_sequences)] for scene in scene_inputs
                    ]
    scene_targets = [
                     [scene[i*batch_seq_len:(i+1)*batch_seq_len] for i in 
                      range(num_mini_sequences)] for scene in scene_targets
                     ]
    num_batch_groups = len(scene_inputs)//batch_size
    batches = []
    for i in range(num_batch_groups):
        group_scene_inputs = scene_inputs[i*batch_size:(i+1)*batch_size]
        group_scene_targets = scene_targets[i*batch_size:(i+1)*batch_size]
        for j in range(num_mini_sequences):
            reset_state = (j == 0)
            batch_inputs = torch.LongTensor(
                [group_scene_inputs[k][j] for k in range(batch_size)]
                )
            batch_targets = torch.LongTensor(
                [group_scene_targets[k][j] for k in range(batch_size)]
                )
            batches.append((reset_state, batch_inputs, batch_targets))
    return batches

In [11]:
batches = get_batches(text_ints, scene_length, batch_size, batch_seq_len)
batches[0][1].shape

torch.Size([8, 64])

In [12]:
script = [  int_to_vocab[y.item()] for y in [x for x in batches[1][1][3]] ]
script

['not',
 'be',
 'wearing',
 'those',
 'pants',
 '|fullstop|',
 '|newline|',
 'Joey:',
 'I',
 'say',
 'push',
 'her',
 'down',
 'the',
 'stairs',
 '|fullstop|',
 '|newline|',
 'Phoebe',
 '|comma|',
 'Ross',
 '|comma|',
 'Chandler',
 '|comma|',
 'and',
 'Joey:',
 'Push',
 'her',
 'down',
 'the',
 'stairs',
 '|exclamation|',
 'Push',
 'her',
 'down',
 'the',
 'stairs',
 '|exclamation|',
 'Push',
 'her',
 'down',
 'the',
 'stairs',
 '|exclamation|',
 '|newline|',
 '|leftparen|',
 'She',
 'is',
 'pushed',
 'down',
 'the',
 'stairs',
 'and',
 'everyone',
 'cheers',
 '|fullstop|',
 '|rightparen|',
 '|newline|',
 'Rachel:',
 "C'mon",
 'Daddy',
 '|comma|',
 'listen',
 'to',
 'me']

In [13]:
class Model(nn.Module):
    
    def __init__(self, num_words, embed_size, rnn_size):
        super().__init__()
        self.rnn_size = rnn_size
        self.state = None
        self.embedding = nn.Embedding(num_words, embed_size)
        self.rnn = nn.LSTM(embed_size, rnn_size, batch_first=True)
        self.decoder = nn.Linear(rnn_size, num_words)
        self.reset_next_state = False
        
    def reset_state(self):
        self.reset_next_state = True
        
    def forward(self, x):
        if self.reset_next_state:
            self.state = (
                x.new_zeros(1, x.size(0), self.rnn_size).float(), 
                x.new_zeros(1, x.size(0), self.rnn_size).float())
            self.reset_next_state = False
        x = self.embedding(x)
        state = self.state if self.state is not None else None
        x, state = self.rnn(x, state)
        self.state = (state[0].data, state[1].data)
        x = self.decoder(x)
        return x

In [14]:
model = Model(num_words, embed_size, rnn_size)

In [15]:
# Setup device
dev = torch.device("cuda")

In [16]:
# Move model to device
model = model.to(dev)

In [17]:
# Define script generation function
def generate_script(model, seq_len, script_start):
    # Convert punctuaction in script start
    for punct, token in token_dict.items():
        script_start = script_start.replace(punct, f' {token} ')
    # Convert script start text to ints
    script_start = [vocab_to_int[word] for word in script_start.split(" ") if len(word) > 0]
    # Initialize output words/tokens
    script = script_start[:]
    # Convert script start to tensor (BxS = 1xS)
    script_start = torch.LongTensor(script_start).unsqueeze_(0)
    # Process script start and generate the rest of the script
    model.eval()
    model.reset_state()
    input = script_start
    for i in range(seq_len - script_start.size(1) + 1): # we include script_start as one of the generation steps
        # Copy input to device
        input = input.to(dev)
        # Pass to model
        output = model(input) # 1xSxV
        # Convert to word indexes
        words = output.max(2)[1] # 1xS
        words = words[0] # S
        # Add each word to script
        for j in range(words.size(0)):
            script.append(words[j].item())
        # Prepare next input
        input = torch.LongTensor([words[-1]]).unsqueeze(0) # 1xS = 1x1
    # Convert word indexes to text
    script = ' '.join([int_to_vocab[x] for x in script])
    # Convert punctuation tokens to symbols
    for punct,token in token_dict.items():
        script = script.replace(f"{token}", punct)
    # Return
    return script

In [18]:
generate_script(model, 20, "Monica: You can't believe it")

"Monica: You can't believe it w-a-n-t Netherlands anal summers quality symbolism Isn't muddy protectors Graff outlet ow ow hungover deviated SIMPLE academy dancer's conveniently Andrew"

In [19]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [21]:
# Initialize training history
loss_history = []
# Start training
for epoch in range(20):
    # Initialize accumulators for computing average loss/accuracy
    epoch_loss_sum = 0
    epoch_loss_cnt = 0
    # Set network mode
    model.train()
    # Process all batches
    for i,batch in enumerate(batches):
        # Parse batch
        reset_state, input, target = batch
        # Check reset state
        if reset_state:
            model.reset_state()
        # Move to device
        input = input.to(dev)
        target = target.to(dev)
        # Forward
        output = model(input)
        # Compute loss
        output = output.view(-1, num_words)
        target = target.view(-1)
        loss = F.cross_entropy(output, target)
        # Update loss sum
        epoch_loss_sum += loss.item()
        epoch_loss_cnt += 1
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Shift sequence and recompute batches
    shift_point = random.randint(1, len(text_ints)-1)
    text_ints = text_ints[:shift_point] + text_ints[shift_point:]
    batches = get_batches(text_ints, scene_length, batch_size, batch_seq_len)
    # Epoch end - compute average epoch loss
    avg_loss = epoch_loss_sum/epoch_loss_cnt
    print(f"Epoch: {epoch+1}, loss: {epoch_loss_sum/epoch_loss_cnt:.4f}")
    print("Test sample:")
    print("---------------------------------------------------------------")
    print(generate_script(model, scene_length, "Ross:"))
    print("---------------------------------------------------------------")
    # Add to histories
    loss_history.append(avg_loss)

Epoch: 1, loss: 1.3750
Test sample:
---------------------------------------------------------------
Ross: Hey , where's Chandler ? 
 Joey: Yeah . 
 Ross: Oh , you know , you haven't seen her like this ! 
 Rachel: Oh ! 
 Ross: Oh , I know . . . I know . . . 
 Rachel: Oh my God ! 
 Ross: What ? 
 Rachel: Oh my God ! 
 Phoebe: What ? 
 Joey: I don't know . . . I mean , I don't wanna look in a good way . 
 Phoebe: Yeah , but you're back's a zero . You're not going to Paris . 
 Joey: What ? 
 Ross: I don't know . 
 Rachel: Ross , I don't even know how that I feel about you . 
 Ross: Yeah , yeah , but he did ! 
 Joey: Yeah , you know , I mean , you know , you and I used them all in school . 
 Phoebe: Yeah , I mean , you know , you did what you thought it would be . 
 ( Rachel enters . ) 
 Rachel: Hey . 
 Phoebe: Hey . 
 Monica: Hey . 
 Chandler: Hey . 
 Ross: Hey . 
 Rachel: Hey . 
 Ross: Hey . 
 Chandler: Hey . 
 Ross: Hey . 
 Rachel: Oh , hi Joey . 
 Ross: Hey . 
 Rachel: Hi . 
 Ross: Gues

In [22]:
# Generate script
print(generate_script(model, scene_length, "Monica: You can't believe it"))

Monica: You can't believe it ( don't move in ! I have a job , I have to get back to the hotel . 
 Joey: I love it ! 
 Ross: Yeah , but you can't because she's here . 
 Joey: ( taking apart the ) Phoebes , these are they pants ? 
 Phoebe: Yeah , I know , I know . Before I know , I know , but it's kinda far away , but it was worth it . 
 Monica: Yeah , but , I think it would be different if you thought it would be something that you two would talk to . 
 Chandler: I don't think so . 
 Monica: Oh , right . 
 Chandler: Okay . 
 ( They enter . ) 
 Chandler: Don't go . 
 Joey: I love you too Chandler . 
 Chandler: I love you too . 
 Joey: ( to Rachel ) This is everybody . 
 Rachel: Oh , d'you like it ? 
 Ross: Oh , absolutely . . . 
 Rachel: Oh . . . 
 ( Ross looks shocked . ) 
 Ross: That's not what we talked about . 
 Chandler: Oh , I didn't factor in such a year long year ago and I never want to get married outside of work today . 
 Ross: You know , I think we should get a room . 
 Chandl