<a href="https://colab.research.google.com/github/gianluigimazzaglia/SpeechGeneration_friends/blob/main/SpeechGeneration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
# Imports
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn; cudnn.benchmark = True

In [2]:
!gdown --id 1syp8QemrZ4sZtaY-2-DwIXx0630VhK4l

Downloading...
From: https://drive.google.com/uc?id=1syp8QemrZ4sZtaY-2-DwIXx0630VhK4l
To: /content/Friends_Transcript.txt
100% 4.90M/4.90M [00:00<00:00, 117MB/s]


In [3]:
# Options
data_path = "Friends_Transcript.txt"
batch_size = 8
batch_seq_len = 16
embed_size = 512
rnn_size = 1024

In [4]:
# Load data
with open(data_path) as f:
    text = f.read()
# Skip notice
text = text[180:]  # we eliminated the first part of the text that represents just a description meaningless

In [5]:
text[:200]

"Monica: There's nothing to tell! He's just some guy I work with!\nJoey: C'mon, you're going out with the guy! There's gotta be something wrong with him!\nChandler: All right Joey, be nice. So does he ha"

In [6]:
### Replace punctuation with tokens ###
# Create token dictionary
token_dict = {".": "|fullstop|",
              ",": "|comma|",
              "\"": "|quote|",
              ";": "|semicolon|",
              "!": "|exclamation|",
              "?": "|question|",
              "(": "|leftparen|",
              ")": "|rightparen|",
              "--": "|dash|",
              "\n": "|newline|"
}
# Replace punctuation
for punct, token in token_dict.items():
    text = text.replace(punct, f' {token} ')

In [7]:
#Print sample
text[:200]

"Monica: There's nothing to tell |exclamation|  He's just some guy I work with |exclamation|  |newline| Joey: C'mon |comma|  you're going out with the guy |exclamation|  There's gotta be something wron"

In [8]:
### Compute vocabulary ###

# Split words
words = text.split(" ")
# Remove empty words
words = [word for word in words if len(word) > 0]
# Remove duplicates
vocab = list(set(words))  #we have just unique word in vocab

In [9]:
for i, w in enumerate(vocab[:5]):
  print(i, w)

0 EYES
1 shoos
2 glass]
3 Fish
4 Marge:


In [10]:
# Create maps between words
vocab_to_int = {word: i for i,word in enumerate(vocab)}
int_to_vocab = {i: word for i,word in enumerate(vocab)}
#vocab_to_int['bright']   for example this work is a key and it is transformed in integ 1

In [26]:
# Compute number of words
num_words = len(vocab)
print(num_words)  # it is the total of unique words

27750


In [11]:
print(len([word for word in text.split(" ") if len(word) > 0]))

1207503


In [12]:
# Convert text to integers
text_ints = [vocab_to_int[word] for word in text.split(" ") if len(word) > 0] 
text_ints[:5]

[8517, 18790, 1057, 9099, 18225]

In [13]:
len(text_ints) #of course the length of this must be equals to the length of num of word because they are just converted in integer

1207503

In [14]:
# Estimate average scene length
import re

scene = re.findall(r'\[Scene.*?\]', text)

num_scenes = len(scene)
print(len(text_ints)/num_scenes)

389.8944139489829


In [15]:
new_text = [word for word in text.split(" ") if len(word) > 0]
inputs = new_text[:10]
target = new_text[1:10]

print(inputs)
print(target)

['Monica:', "There's", 'nothing', 'to', 'tell', '|exclamation|', "He's", 'just', 'some', 'guy']
["There's", 'nothing', 'to', 'tell', '|exclamation|', "He's", 'just', 'some', 'guy']


In [16]:
# Set scene length (should be multiple of batch_seq_len)
scene_length = 256

In [17]:
# Compute batches
# Needs to be a function so we can compute different batches at different epochs
def get_batches(text_ints, scene_length, batch_size, batch_seq_len):
    # Compute number of "scenes"
    num_scenes = len(text_ints)//scene_length
    # Compute targets for each word (with fake target for final word)
    text_targets = text_ints[1:] + [text_ints[0]]
    # Split text into scenes (input and targets)
    scene_inputs = [text_ints[i * scene_length : (i+1) * scene_length] for i in range(num_scenes)]
    scene_targets = [text_targets[i*scene_length:(i+1)*scene_length] for i in range(num_scenes)]
    # Split scenes into mini-sequences of length batch_seq_len
    num_mini_sequences = scene_length//batch_seq_len
    scene_inputs = [[scene[i*batch_seq_len:(i+1)*batch_seq_len] for i in range(num_mini_sequences)] for scene in scene_inputs]
    scene_targets = [[scene[i*batch_seq_len:(i+1)*batch_seq_len] for i in range(num_mini_sequences)] for scene in scene_targets]
    # Build batches
    num_batch_groups = len(scene_inputs)//batch_size
    batches = []
    for i in range(num_batch_groups):
        # Get the scenes in this group
        group_scene_inputs = scene_inputs[i*batch_size:(i+1)*batch_size]
        group_scene_targets = scene_targets[i*batch_size:(i+1)*batch_size]
        # Build batches for each mini-sequence
        for j in range(num_mini_sequences):
            reset_state = (j == 0)
            batch_inputs = torch.LongTensor([group_scene_inputs[k][j] for k in range(batch_size)])
            batch_targets = torch.LongTensor([group_scene_targets[k][j] for k in range(batch_size)])
            batches.append((reset_state, batch_inputs, batch_targets))
    # Return
    return batches

In [22]:
# Get batches
batches = get_batches(text_ints, scene_length, batch_size, batch_seq_len)
batches[0][1].shape

torch.Size([8, 16])

In [23]:
script = [  int_to_vocab[y.item()] for y in [x for x in batches[1][1][3]] ]
script

['|leftparen|',
 'to',
 'All',
 '|rightparen|',
 'Okay',
 '|comma|',
 'everybody',
 '|comma|',
 'this',
 'is',
 'Rachel',
 '|comma|',
 'another',
 'Lincoln',
 'High',
 'survivor']

In [24]:
# Define model
class Model(nn.Module):
    
    # Constructor
    def __init__(self, num_words, embed_size, rnn_size):
        # Call parent constructor
        super().__init__()
        # Store needed attributes
        self.rnn_size = rnn_size
        self.state = None
        # Define modules
        self.embedding = nn.Embedding(num_words, embed_size)
        self.rnn = nn.LSTM(embed_size, rnn_size, batch_first=True)
        self.decoder = nn.Linear(rnn_size, num_words)
        # Flags
        self.reset_next_state = False
        
    def reset_state(self):
        # Mark next state to be re-initialized
        self.reset_next_state = True
        
    def forward(self, x):
        # Check state reset
        if self.reset_next_state:
            # Initialize state (num_layers x batch_size x rnn_size)
            self.state = (
                x.new_zeros(1, x.size(0), self.rnn_size).float(), 
                x.new_zeros(1, x.size(0), self.rnn_size).float())
            # Clear flag
            self.reset_next_state = False
        # Embed data
        x = self.embedding(x)
        # Process RNN
        state = self.state if self.state is not None else None
        x, state = self.rnn(x, state)
        self.state = (state[0].data, state[1].data)
        # Compute outputs
        x = self.decoder(x)
        return x

In [28]:
# Create model
model = Model(num_words, embed_size, rnn_size)

In [31]:
# Setup device
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [32]:
# Move model to device
model = model.to(dev)

In [33]:
# Define script generation function
def generate_script(model, seq_len, script_start):
    # Convert punctuaction in script start
    for punct, token in token_dict.items():
        script_start = script_start.replace(punct, f' {token} ')
    # Convert script start text to ints
    script_start = [vocab_to_int[word] for word in script_start.split(" ") if len(word) > 0]
    # Initialize output words/tokens
    script = script_start[:]
    # Convert script start to tensor (BxS = 1xS)
    script_start = torch.LongTensor(script_start).unsqueeze_(0)
    # Process script start and generate the rest of the script
    model.eval()
    model.reset_state()
    input = script_start
    for i in range(seq_len - script_start.size(1) + 1): # we include script_start as one of the generation steps
        # Copy input to device
        input = input.to(dev)
        # Pass to model
        output = model(input) # 1xSxV
        # Convert to word indexes
        words = output.max(2)[1] # 1xS
        words = words[0] # S
        # Add each word to script
        for j in range(words.size(0)):
            script.append(words[j].item())
        # Prepare next input
        input = torch.LongTensor([words[-1]]).unsqueeze(0) # 1xS = 1x1
    # Convert word indexes to text
    script = ' '.join([int_to_vocab[x] for x in script])
    # Convert punctuation tokens to symbols
    for punct,token in token_dict.items():
        script = script.replace(f"{token}", punct)
    # Return
    return script

In [34]:
generate_script(model, 20, "Rachel: What?")

"Rachel: What ? himself/herself crazy deaths cheekily Aunt 'good observing Awe 1989 fan Leakey's Nails runs translation It-it's terribly drip Greens Badges Organic"

In [35]:
# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [36]:
# Initialize training history
loss_history = []
# Start training
for epoch in range(20):
    # Initialize accumulators for computing average loss/accuracy
    epoch_loss_sum = 0
    epoch_loss_cnt = 0
    # Set network mode
    model.train()
    # Process all batches
    for i,batch in enumerate(batches):
        # Parse batch
        reset_state, input, target = batch
        # Check reset state
        if reset_state:
            model.reset_state()
        # Move to device
        input = input.to(dev)
        target = target.to(dev)
        # Forward
        output = model(input)
        # Compute loss
        output = output.view(-1, num_words)
        target = target.view(-1)
        loss = F.cross_entropy(output, target)
        # Update loss sum
        epoch_loss_sum += loss.item()
        epoch_loss_cnt += 1
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Shift sequence and recompute batches
    shift_point = random.randint(1, len(text_ints)-1)
    text_ints = text_ints[:shift_point] + text_ints[shift_point:]
    batches = get_batches(text_ints, scene_length, batch_size, batch_seq_len)
    # Epoch end - compute average epoch loss
    avg_loss = epoch_loss_sum/epoch_loss_cnt
    print(f"Epoch: {epoch+1}, loss: {epoch_loss_sum/epoch_loss_cnt:.4f}")
    print("Test sample:")
    print("---------------------------------------------------------------")
    print(generate_script(model, scene_length, "Monica:"))
    print("---------------------------------------------------------------")
    # Add to histories
    loss_history.append(avg_loss)

KeyboardInterrupt: ignored

In [None]:
# Generate script
print(generate_script(model, scene_length, "Monica: Really? "))

NameError: ignored