In [None]:
# Show the NVIDIA driver/GPU status to check which GPU (if any) is attached to this runtime.
!nvidia-smi

In [None]:
!pip install transformers

import os
import math
import gc
import json
from google.colab import drive
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import CTRLTokenizer, CTRLLMHeadModel, CTRLConfig

Use the next cell only if you need to connect to Google Drive; otherwise you can skip it.

In [None]:
# connect to drive

# Mount Google Drive under /content/gdrive.
drive.mount('/content/gdrive')
# Path template for the project folder on the drive; the "{}" placeholder takes a path relative to that folder.
files_dir = "/content/gdrive/My Drive/PRJ/{}"

# Formatting with an empty string yields the project folder itself.
base_file_dir = files_dir.format("")

# Make the project folder the working directory so the relative paths used by later cells resolve against it.
%cd "{base_file_dir}"

In [None]:
# setup model

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a runtime without one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Number of transformer layers passed to CTRLConfig (n_layer).
layers = 10

print("Creating model...")
model = CTRLLMHeadModel(CTRLConfig(n_layer=layers))
model = model.to(device)
print("Model created.")

print("Loading model checkpoint...")

# load partial trained model
# Alternative checkpoint (swap with the line below to use it):
#model.load_state_dict(torch.load('./trained/model_trained_shuffle.bin'))
# map_location="cpu" deserialises the checkpoint on the CPU whatever device it was saved from;
# load_state_dict then copies the values into the model's parameters on `device`.
# This avoids a failure on CPU-only runtimes and a temporary second copy of the weights on the GPU.
model.load_state_dict(torch.load('./pytorch_model10.bin', map_location="cpu"))
print("Model checkpoint loaded")

tokenizer = CTRLTokenizer.from_pretrained('ctrl')
# Register '~' as the padding token so batches of unequal length can be padded.
# NOTE(review): if '~' is not already in the vocabulary this adds a new token id and the model's
# embedding matrix would need resizing — confirm that add_special_tokens reports 0 added tokens.
tokenizer.add_special_tokens({'pad_token': '~'})
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)

In [None]:
# dataset for new files

def load_coco_data(file_path):
    """Read a text file into a dict keyed by 0-based line number.

    Every value is one raw line of the file, with its trailing newline
    (if any) left in place. A status message is printed in both cases.

    Args:
        file_path: path of the text file to read.

    Returns:
        dict mapping line index -> line text; empty when ``file_path``
        is not an existing regular file.
    """
    if not os.path.isfile(file_path):
        print("File not found")
        return {}

    print("Load from", file_path)
    with open(file_path, "r") as handle:
        # Iterating the handle yields the same lines as readlines().
        sentences = dict(enumerate(handle))

    print("Dataset loaded")
    return sentences


class CocoDataset(Dataset):
  """Map-style dataset over the lines returned by ``load_coco_data``."""

  def __init__(self, file_path):
    """Load ``file_path`` and cache the number of entries in ``self.len``."""
    loaded = load_coco_data(file_path)
    self.data = loaded
    self.len = len(loaded)

  def __len__(self):
    """Return the number of loaded entries."""
    return self.len

  def __getitem__(self, index):
    """Return the entry stored under ``index``, converted to ``str``."""
    return str(self.data[index])

# Build the dataset from a single text file; the path is relative to the current working directory.
# If the file is missing, load_coco_data prints "File not found" and the dataset is empty.
dataset = CocoDataset("data_shuffle/file12.txt")

In [None]:
# Training

# The loader is used with its default batch size of 1, so it yields one sentence at a time;
# mini-batches of `batch_s` sentences are assembled by hand below.
data_loader = DataLoader(dataset)

batch_s = 10      # sentences per optimisation step
log_every = 500   # print progress every this many optimisation steps
step_count = 0    # optimisation steps performed so far
pending = []      # sentences collected for the next step
last_index = len(data_loader) - 1
model.train()

for i, data in enumerate(data_loader):

  # Step 1: Collect sentences until a full batch is available
  pending.append(data[0])  # data is a one-element list holding the sentence

  # Train on every full batch, plus the final (possibly shorter) one.
  # Counting the collected sentences keeps every batch at exactly `batch_s` items
  # (a test on `i % batch_s` would make the first batch one sentence larger).
  if len(pending) < batch_s and i != last_index:
    continue  # keep collecting

  # Named `batch_inputs` rather than `input` so the builtin input() used by the
  # generation cell is never shadowed, even if this cell is interrupted mid-step.
  batch_inputs = tokenizer(pending, return_tensors='pt', padding=True, truncation=True).to(device)
  pending.clear()

  # Step 2: Zero the parameter gradients - always do this before loss.backward()
  optimizer.zero_grad()

  # Step 3: Forward pass (the model computes the language-modelling loss from the labels)
  # NOTE(review): the labels are the raw input_ids, so padded positions also contribute to the
  # loss. The generation cell treats '~' (the pad token) as a stop marker, so this looks
  # intentional — confirm before masking padding out of the labels.
  outputs = model(**batch_inputs, labels=batch_inputs["input_ids"])

  # Drop our reference to the batch tensors and release cached GPU memory
  del batch_inputs
  torch.cuda.empty_cache()
  gc.collect()

  # Step 4: Read the loss computed by the model
  loss = outputs.loss

  # Step 5: Compute gradients for each of the model's learnable parameters
  loss.backward()

  # Step 6: Update model parameters according to the gradients
  optimizer.step()

  # Progress report: index of the last sentence consumed, every `log_every` steps
  step_count += 1
  if step_count % log_every == 0:
    print(i)


# End Training here
print("Training done")

print("Saving model...")
torch.save(model.state_dict(), "/content/gdrive/My Drive/PRJ/trained/model_from_scratch.bin")
print("Model saved")

In [None]:
# generation with model already loaded

print('INFO :: Start!!')

# The training cell leaves the model in train mode; switch to eval mode so dropout
# is disabled while sampling.
model.eval()

# setup
seq_length = 30   # added to the prompt length to obtain max_length for generate()
temperature = 1.0 #default=1.0
nucleusprob = 0.9 #default=0.9
penalty = 1.2     #help="primarily useful for CTRL model; in that case, use 1.2"
topk = 0          #default=0

# get prompt
# Vehicle A man sitting on his motorcycle with one hand on his helmet. 
# Food Several dishes contain a wide variety of vegetables.
# Outdoor Two women walk across a street-crossing in a city with cars and a bus driving by.
# Person A man stands barefoot on a sandy beach.
# Electronic A boy using his phone and fanning himself.
prompt = input('ENTER PROMPT: ')
if not prompt.strip():
  # An empty prompt encodes to an empty tensor, which generate() cannot extend.
  raise ValueError("The prompt must not be empty")

encoded_CTRL = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
encoded_input = encoded_CTRL.to(device)

# Number of tokens in the prompt (encoded_input holds a single sequence)
len_prompt = encoded_input.shape[1]

# Sample a continuation; the returned ids include the prompt tokens
output_sequence = model.generate(
  input_ids=encoded_input,
  max_length=seq_length + len_prompt,
  temperature=temperature,
  top_k=topk,
  top_p=nucleusprob,
  repetition_penalty=penalty,
  do_sample=True,
  num_return_sequences=1,
)

# Decode the generated token ids back to text
generated_sequence = output_sequence[0].tolist()
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

# Remove all text after the stop token.
# Markers are applied in order to the progressively shortened text; the period itself is
# kept, while '~' and the newline are dropped.
# NOTE(review): the decoded text starts with the prompt, so a marker inside the prompt
# truncates the output there — confirm this is intended.
for marker, keep in ((".", 1), ("~", 0), ("\n", 0)):
  cut = text.find(marker)
  if cut != -1:
    text = text[: cut + keep]

print("=== GENERATED SEQUENCE ===")
print(text)

print('INFO :: Generation done!')