In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel

import argparse
import logging
from tqdm import trange

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [18]:
def format_list(f_list):
    f_list = [str.decode("utf-8") for str in f_list]
    new_f_list = []
    main_index = 0
    sub_str = ""
    while main_index < len(f_list):
        replace_F = str(f_list[main_index]).replace("\n", "")
        if f_list[main_index] == "**EOF**\n":
            sub_str += f" {replace_F}"
            sub_str = ""
            new_f_list.append(sub_str)
        else:
            sub_str += f" {replace_F}"
        main_index += 1
    return new_f_list

In [21]:
from torch.utils.data import Dataset, DataLoader
import gzip

class CommandDataset(Dataset):  
    def __init__(self, fnames, control_code="startoftext", truncate=False, gpt2_type="gpt2", max_length=128):

        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.lyrics = []
        for fname in fnames:
            f = gzip.open(fname,'r') 
            f = format_list(f)
            for row in f:
                self.lyrics.append(torch.tensor(
                        self.tokenizer.encode(f"<|{control_code}|>{row[:max_length]}<|endoftext|>")
                    ))               
        if truncate:
            self.lyrics = self.lyrics[:20000]
        self.lyrics_count = len(self.lyrics)
    def __len__(self):
        return self.lyrics_count

    def __getitem__(self, item):
        return self.lyrics[item]
    
dataset = CommandDataset([
    "UNIX_user_data/USER0.gz",
    "UNIX_user_data/USER1.gz", 
    "UNIX_user_data/USER2.gz",
    "UNIX_user_data/USER3.gz", 
    "UNIX_user_data/USER4.gz",
    "UNIX_user_data/USER5.gz",
    "UNIX_user_data/USER6.gz",
    "UNIX_user_data/USER7.gz",
    "UNIX_user_data/USER8.gz",
], truncate=True, gpt2_type="gpt2")      

In [22]:
#Get the tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

#Accumulated batch size (since GPT2 is so big)
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    if packed_tensor is None:
        return new_tensor, True, None
    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
        return packed_tensor, False, new_tensor
    else:
        packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
        return packed_tensor, True, None

In [None]:
from transformers import get_linear_schedule_with_warmup, AdamW
from tqdm import tqdm, trange
import os
def train(
    dataset, model, tokenizer,
    batch_size=128, epochs=250, lr=2e-5,
    max_seq_len=400, warmup_steps=200,
    gpt2_type="gpt2", output_dir=".", output_prefix="wreckgar",
    test_mode=False,save_model_on_epoch=True,
):
    acc_steps = 100
    device=torch.device("cuda")
    model = nn.DataParallel(model, device_ids = [0,1,2])
    model = model.cuda()
    model.train()

    optimizer = AdamW(model.parameters(), lr=lr)

    train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    loss=0
    accumulating_batch_count = 0
    input_tensor = None

    for epoch in range(epochs):
        print(f"Training epoch {epoch}")
        print(f"Loss:{loss}")
        for idx, entry in tqdm(enumerate(train_dataloader)):
            (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, 768)

            if carry_on and idx != len(train_dataloader) - 1:
                continue

            input_tensor = input_tensor.to(device)
            outputs = model(input_tensor, labels=input_tensor)
            loss = outputs[0]
            loss.backward()

            if (accumulating_batch_count % batch_size) == 0:
                optimizer.step()
                optimizer.zero_grad()
                model.zero_grad()

            accumulating_batch_count += 1
            input_tensor = None
        if save_model_on_epoch:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"linux-gpt.pt"),
            )
    return model

In [24]:
model = train(dataset, model, tokenizer)

110it [00:00, 685.75it/s]

Training epoch 0
Loss:0


11113it [00:12, 859.00it/s]
220it [00:00, 1352.41it/s]

Training epoch 1
Loss:4.2905707359313965


2640it [00:03, 861.59it/s]


KeyboardInterrupt: 