### Import libraries

In [2]:
import pandas as pd
import os
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers.optimization import Adafactor
import time
import warnings
from IPython.display import HTML, display
warnings.filterwarnings('ignore')

In [3]:
torch.cuda.empty_cache()

### Build required functions inside a class

In [4]:
class ContentGenerator():
    def __init__(self):
        self.batch_size = 8
        self.num_of_epochs = 5
        self.checkpoint_path = '../checkpoint/business-description-generator-model.bin'
        self.config_path = '../data/t5-base-config.json'
    
    def getData(self):
        self.train_df = pd.read_csv('data/company-nlg-data.csv')
        self.num_of_batches = int(len(self.train_df)/self.batch_size)
        
    def getDevice(self):
        if torch.cuda.is_available():
            self.dev = torch.device("cuda:0")
            print("Running on the GPU")
        else:
            self.dev = torch.device("cpu")
            print("Running on the CPU")
    
    def getTokenizer(self):
        self.tokenizer = T5Tokenizer.from_pretrained('t5-base')
    
    def getModel(self):
        self.getDevice()
        self.getTokenizer()
        self.model = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)
        self.model.to(self.dev)
        self.optimizer = Adafactor(
            self.model.parameters(),
            lr=1e-3,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=None,
            weight_decay=0.0,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False
        )
        
    def progress(self, loss,value, max=100):
        return HTML(""" Batch loss :{loss}
            <progress
                value='{value}'
                max='{max}',
                style='width: 100%'
            >
                {value}
            </progress>
        """.format(loss=loss,value=value, max=max))
    
    def fit(self):
        self.getModel()
        self.model.train()
        self.train_df = self.train_df.dropna()
        loss_per_10_steps=[]
        for epoch in range(1, self.num_of_epochs+1):
            print('Running epoch: {}'.format(epoch))
            running_loss=0

            out = display(self.progress(1, self.num_of_batches+1), display_id=True)
            for i in range(self.num_of_batches):
                inputbatch = []
                labelbatch = []
                new_df = self.train_df[i*self.batch_size:i*self.batch_size+self.batch_size]
                for indx,row in new_df.iterrows():
                    input = 'WebNLG: '+row['input_text']+'</s>' 
                    labels = row['target_text']+'</s>'   
                    inputbatch.append(input)
                    labelbatch.append(labels)

                if len(inputbatch) < 1:
                    continue
                inputbatch = self.tokenizer.batch_encode_plus(inputbatch,padding=True,max_length=400,return_tensors='pt')["input_ids"]
                labelbatch = self.tokenizer.batch_encode_plus(labelbatch,padding=True,max_length=400,return_tensors="pt") ["input_ids"]
                inputbatch = inputbatch.to(self.dev)
                labelbatch = labelbatch.to(self.dev)

                self.optimizer.zero_grad()

                outputs = self.model(input_ids=inputbatch, labels=labelbatch)
                loss = outputs.loss
                loss_num = loss.item()
                logits = outputs.logits
                running_loss += loss_num
                if i%10 ==0:
                    loss_per_10_steps.append(loss_num)
                out.update(self.progress(loss_num, i, self.num_of_batches+1))
                loss.backward()
                self.optimizer.step()

            running_loss = running_loss/int(self.num_of_batches)
            print('Epoch: {} , Running loss: {}'.format(epoch,running_loss))
        self.saveModel()
        self.emptyCudaCache()

    def saveModel(self):
        torch.save(self.model.state_dict(), self.checkpoint_path)
    
    def loadModel(self):
        return T5ForConditionalGeneration.from_pretrained(self.checkpoint_path, return_dict=True, config=self.config_path)
    
    def emptyCudaCache(self):
        torch.cuda.empty_cache()
    
    def generate(self, text):
        torch.manual_seed(0)
        model = self.loadModel()
        model.eval()
        input_ids = self.tokenizer.encode("WebNLG:{} </s>".format(text), return_tensors="pt")
        sample_outputs = model.generate(
            input_ids,
            do_sample=True,
            max_length=50, 
            top_k=4, 
            top_p=0.99,
            num_return_sequences=10
        )

        print("Output:\n" + 100 * '-')
        for i, sample_output in enumerate(sample_outputs):
          print("{}: {}".format(i, self.tokenizer.decode(sample_output, skip_special_tokens=True)))

### Initialize class object

In [27]:
obj = ContentGenerator()

### Train the pretrained model on custom dataset

In [7]:
obj.fit()

### Test: Generate Google ads suggestion based on company description's keywords

In [5]:
obj = ContentGenerator()
obj.getTokenizer()

In [6]:
key = 'Pied Piper | blockchain | new internet'
obj.generate(key)

Output:
----------------------------------------------------------------------------------------------------
0: Pied Piper builds blockchain and empowers people to invest in assets that belong to the new Internet.
1: Pied Piper is bringing blockchain to the new, secure, and decentralized Internet.
2: Pied Piper builds blockchain based tools for the new, fully regulated internet.
3: Pied Piper is a blockchain-powered, decentralized marketplace for the next generations of internet-connected assets.
4: Pied Piper is bringing blockchain to the new, decentralized Internet.
5: Pied Piper enables blockchain developers to build, operate, and scale to the new, open Internet of things.
6: Pied Piper enables blockchain and blockchain developers to build secure, secure and decentralized platforms for the new, uncensored internet.
7: Pied Piper provides blockchain and other services to the new, uncensored internet.
8: Pied Piper is a blockchain based tools and services for the new, secure, decentra