In [None]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np

import logging
logging.getLogger().setLevel(logging.CRITICAL)

import warnings
warnings.filterwarnings('ignore')

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model = model.to(device)

In [None]:
def choose_from_top(probs, n=5):
    ind = np.argpartition(probs, -n)[-n:]
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob)
    choice = np.random.choice(n, 1, p = top_prob)
    token_id = ind[choice][0]
    return int(token_id)

In [None]:
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader
import os
import json
import csv
import pandas as pd
from google.colab import files
class JokesDataset(Dataset):
    def __init__(self):
        super().__init__()
        uploaded = files.upload()
        self.joke_list = []
        self.end_of_text_token = "<|endoftext|>"
        with open(list(uploaded.keys())[0]) as csv_file:
            csv_reader = pd.read_csv(csv_file, delimiter=',',skip_blank_lines=True)
            print(csv_reader)
            x = 0
            for row in csv_reader:
                joke_str = f"JOKE:{row[1]}{self.end_of_text_token}"
                print(joke_str)
                self.joke_list.append(joke_str)

    def __len__(self):
        return len(self.joke_list)

    def __getitem__(self, item):
        return self.joke_list[item]

In [None]:
dataset = JokesDataset()
joke_loader = DataLoader(dataset, batch_size=1, shuffle=True)

Saving shortjokes.csv to shortjokes (25).csv
     JOKE:The new iPhone 7 will be called "The iPhone 8." It'll be called the iPhone 8 because that's the name of the phone.<|endoftext|>
0                                                   NaN                                                                                  
1     JOKE:What do you call it if you have a dog wit...                                                                                  
2                                                   NaN                                                                                  
3     JOKE:What's the difference between the Pope an...                                                                                  
4                                                   NaN                                                                                  
...                                                 ...                                                                        

In [None]:
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 3e-5
WARMUP_STEPS = 5000
MAX_SEQ_LEN = 400
from transformers import AdamW

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'


In [None]:
model = model.to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
proc_seq_count = 0
sum_loss = 0.0
batch_count = 0

tmp_jokes_tens = None
models_folder = "trained_models"
if not os.path.exists(models_folder):
    os.mkdir(models_folder)

for epoch in range(EPOCHS):

    print(f"EPOCH {epoch} started" + '=' * 30)

    for idx,joke in enumerate(joke_loader):

        joke_tens = torch.tensor(tokenizer.encode(joke[0])).unsqueeze(0).to(device)
        if joke_tens.size()[1] > MAX_SEQ_LEN:
            continue

        if not torch.is_tensor(tmp_jokes_tens):
            tmp_jokes_tens = joke_tens
            continue
        else:
            if tmp_jokes_tens.size()[1] + joke_tens.size()[1] > MAX_SEQ_LEN:
                work_jokes_tens = tmp_jokes_tens
                tmp_jokes_tens = joke_tens
            else:
                tmp_jokes_tens = torch.cat([tmp_jokes_tens, joke_tens[:,1:]], dim=1)
                continue

        outputs = model(work_jokes_tens, labels=work_jokes_tens)
        loss, logits = outputs[:2]
        loss.backward()
        sum_loss = sum_loss + loss.detach().data

        proc_seq_count = proc_seq_count + 1
        if proc_seq_count == BATCH_SIZE:
            proc_seq_count = 0
            batch_count += 1
            optimizer.step()
            optimizer.zero_grad()
            model.zero_grad()

        if batch_count == 100:
            print(f"sum loss {sum_loss}")
            batch_count = 0
            sum_loss = 0.0

    # Store the model after each epoch to compare the performance of them
    torch.save(model.state_dict(), os.path.join(models_folder, f"gpt2_medium_joker_{epoch}.pt"))



In [25]:
MODEL_EPOCH = 4

models_folder = "trained_models"

model_path = os.path.join(models_folder, f"gpt2_medium_joker_{MODEL_EPOCH}.pt")
model.load_state_dict(torch.load(model_path))

jokes_output_file_path = f'generated_{MODEL_EPOCH}.jokes'

model.eval()
if os.path.exists(jokes_output_file_path):
    os.remove(jokes_output_file_path)

joke_num = 0
with torch.no_grad():

        for joke_idx in range(2):
            print(joke_idx)
            print(30*'==')

            joke_finished = False

            cur_ids = torch.tensor(tokenizer.encode("JOKE:")).unsqueeze(0).to(device)

            for i in range(100):
                print(i)
                outputs = model(cur_ids, labels=cur_ids)
                loss, logits = outputs[:2]
                softmax_logits = torch.softmax(logits[0,-1], dim=0)
                if i < 3:
                    n = 20
                else:
                    n = 3
                next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n)
                cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1)

                if next_token_id in tokenizer.encode('<|endoftext|>'):
                    joke_finished = True
                    break


            if joke_finished:

                joke_num = joke_num + 1

                output_list = list(cur_ids.squeeze().to('cpu').numpy())
                output_text = tokenizer.decode(output_list)

                with open(jokes_output_file_path, 'a') as f:
                    f.write(f"{output_text} \n\n")


0
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
1
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [26]:
!ls

 sample_data	       'shortjokes (16).csv'  'shortjokes (22).csv'  'shortjokes (5).csv'
'shortjokes (10).csv'  'shortjokes (17).csv'  'shortjokes (23).csv'  'shortjokes (6).csv'
'shortjokes (11).csv'  'shortjokes (18).csv'  'shortjokes (24).csv'  'shortjokes (7).csv'
'shortjokes (12).csv'  'shortjokes (19).csv'  'shortjokes (25).csv'  'shortjokes (8).csv'
'shortjokes (13).csv'  'shortjokes (1).csv'   'shortjokes (2).csv'   'shortjokes (9).csv'
'shortjokes (14).csv'  'shortjokes (20).csv'  'shortjokes (3).csv'    shortjokes.csv
'shortjokes (15).csv'  'shortjokes (21).csv'  'shortjokes (4).csv'    trained_models
