In [1]:
import sys
import os

# Make the sibling ``datasets`` directory of the parent folder importable so
# that packages living inside it can be imported by later cells.
module_path = os.path.abspath(os.path.join('..'))
# Build the entry with os.path.join so the separator is correct on every OS
# (a hard-coded "\\" only produces a valid path on Windows).
datasets_path = os.path.join(module_path, "datasets")
# Guard on the exact entry being added, so re-running the cell never appends
# a duplicate to sys.path.
if datasets_path not in sys.path:
    sys.path.append(datasets_path)

In [2]:
from wikipedia.extract import extract

# Locations of the raw train / test splits, relative to the notebook.
train_path = "../datasets/wikipedia/raw/wikipedia.train"
test_path = "../datasets/wikipedia/raw/wikipedia.test"

# `extract` returns a pair for each split; the first element is used as a
# DataFrame in later cells, the second is kept as per-split statistics.
df_train, train_stats = extract(train_path)
df_test, test_stats = extract(test_path)

In [3]:
def to_t5_frame(df, prefix="subj_id"):
    """Convert an extracted frame into the three-column layout used for T5.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``text`` column and a ``subject`` column. Each ``text``
        value must be an iterable of strings, because it is joined with
        single spaces.
    prefix : str, default "subj_id"
        Constant written to every row of the ``prefix`` column.

    Returns
    -------
    pandas.DataFrame
        A new frame with columns ``input_text``, ``target_text`` and
        ``prefix`` (in that order); ``df`` itself is not modified.
    """
    return (
        df[["text", "subject"]]
        .rename(columns={"text": "input_text", "subject": "target_text"})
        .assign(
            # Collapse each token sequence into one space-separated string.
            input_text=lambda frame: frame["input_text"].apply(" ".join),
            prefix=prefix,
        )
    )


# Both splits go through the same helper so they cannot drift apart.
train = to_t5_frame(df_train)
test = to_t5_frame(df_test)

# Plain-text preview of the first rows of each split.
print(train.head())
print(test.head())

                                          input_text        target_text  \
0  In recognition of Darwin s pre eminence he was...     Charles Darwin   
1  Adams s most important contributions to Americ...  John Quincy Adams   
2  LaWanda Page October 19 1920 September 14 2002...       LaWanda Page   
3  Born in Randolph County North Carolina Jordan ...  B. Everett Jordan   
4  Aldous Leonard Huxley July 26 1894 November 22...      Aldous Huxley   

    prefix  
0  subj_id  
1  subj_id  
2  subj_id  
3  subj_id  
4  subj_id  
                                          input_text      target_text   prefix
0  Dick Cheney s public service career began unde...      Dick Cheney  subj_id
1  Earling Carothers Garrison November 20 1921 Oc...     Jim Garrison  subj_id
2  In 1816 when Lincoln was seven years old he an...  Abraham Lincoln  subj_id
3  Clive Jay Davis born April 4 1934 is the found...      Clive Davis  subj_id
4  Carroll was formerly the head coach of the New...     Pete Carroll  subj_id

In [3]:
from simpletransformers.t5 import T5Model, T5Args

# Start from the library defaults and override only the options listed below.
model_args = T5Args()
arg_overrides = {
    "num_train_epochs": 5,
    "fp16": False,
    "overwrite_output_dir": True,
    # Early-stopping settings.
    "use_early_stopping": True,
    "early_stopping_delta": 0.01,
    "early_stopping_patience": 5,
    # Evaluation settings.
    "evaluate_generated_text": True,
    "evaluate_during_training": True,
    "evaluate_during_training_verbose": True,
}
for option, value in arg_overrides.items():
    setattr(model_args, option, value)

# Loads the checkpoint stored under outputs/best_model and runs without CUDA.
# change this to t5-base for training
model = T5Model("outputs/best_model", args=model_args, use_cuda=False)

def avg_sim(labels, preds):
    """Return the mean string similarity between labels and predictions.

    Each ``labels[i]`` is compared with ``preds[i]`` using
    ``difflib.SequenceMatcher(...).ratio()``, which yields a value in
    ``[0, 1]`` (1.0 for identical strings). The standard library is used
    because ``textdistance`` is never imported in this notebook, and a module
    object cannot be called as a function anyway.

    Parameters
    ----------
    labels : sequence of str
        Reference strings.
    preds : sequence of str
        Predicted strings, aligned with ``labels``.

    Returns
    -------
    float
        Average similarity over all pairs; ``0.0`` when ``preds`` is empty.

    Raises
    ------
    ValueError
        If ``labels`` and ``preds`` differ in length.
    """
    from difflib import SequenceMatcher

    if len(labels) != len(preds):
        raise ValueError(
            f"labels and preds must have the same length, got {len(labels)} and {len(preds)}"
        )
    # Avoid a ZeroDivisionError when there is nothing to score.
    if not preds:
        return 0.0

    total = sum(
        SequenceMatcher(None, label, pred).ratio()
        for label, pred in zip(labels, preds)
    )
    return total / len(preds)

In [None]:
# For training
# `test` is passed as the evaluation data and `avg_sim` is supplied as an
# extra metric under the keyword name "matches".
# NOTE(review): the test split is also used for evaluation during training;
# confirm this is intended rather than a separate validation split.
model.train_model(
    train,
    eval_data=test,
    matches=avg_sim,
)

In [None]:
# For evaluating
# Evaluates the loaded model on `test`, with `avg_sim` supplied as an extra
# metric under the keyword name "matches".
model.eval_model(
    test,
    matches=avg_sim,
)

In [6]:
# Example
# Three free-text passages, each carrying the same "subj_id: " task prefix
# that the training frames store in their `prefix` column.
example_inputs = [
    "subj_id: Atlantis (Ancient Greek: Ἀτλαντὶς νῆσος, 'island of Atlas') is a fictional island mentioned in an allegory on the hubris of nations in Plato's works Timaeus and Critias, where it represents the antagonist naval power that besieges 'Ancient Athens', the pseudo-historic embodiment of Plato's ideal state in The Republic. In the story, Athens repels the Atlantean attack unlike any other nation of the known world,[2] supposedly bearing witness to the superiority of Plato's concept of a state. The story concludes with Atlantis falling out of favor with the deities and submerging into the Atlantic Ocean.",
    "subj_id: It is fairly certain that, at some point, he went to Kusumapura for advanced studies and lived there for some time. Both Hindu and Buddhist tradition, as well as Bhāskara I (CE 629), identify Kusumapura as Pāṭaliputra, modern Patna. A verse mentions that Aryabhata was the head of an institution (kulapa) at Kusumapura, and, because the university of Nalanda was in Pataliputra at the time and had an astronomical observatory, it is speculated that Aryabhata might have been the head of the Nalanda university as well. Aryabhata is also reputed to have set up an observatory at the Sun temple in Taregana, Bihar.",
    "subj_id: Betty Botter bought some butter from the market but the butter was bitter and made her batter butter bitter so Betty bought some better butter to make the bitter batter butter better.",
]

# Run the passages through the model in a single call and show the results.
predictions = model.predict(example_inputs)

print(predictions)

HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Decoding outputs', max=3.0, style=ProgressStyle(descripti…


['Atlantis', 'Aryabhata', 'Betty Botter']
