This notebook is a guide to fine-tuning a large language model (LLM).

Our example fine-tunes a T5 model on summary-source pairs. The dataset is in the "data/" directory of this folder.

We will use three notebooks for this tutorial:
1. Cleaning the data
2. Preparing the dataset and model configurations
3. Initiating the training instance and training our model

In [4]:
cd ..

/home/eanthony/workspace/github-work/aidiv-sagemaker-examples


## Package Imports

In [5]:
import torch
from vizzy import *
import pandas as pd
import os
import tarfile

## Visualizing and Cleaning the Data

First, we need to unpack the .tgz archive that holds the data:

In [6]:
FILENAME = 'cnn_stories.tgz'
FILEPATH = 'data/'

tar = tarfile.open(os.path.join(FILEPATH,FILENAME))

In [7]:
tar.extractall('data/cnn_stories')
tar.close()  # release the file handle once extraction is done

You should now see a directory called "cnn_stories" that contains the .story files.
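As a minimal sketch of the same extraction step, the archive can be opened in a `with` block so the file handle is closed automatically, and its member names can be inspected before anything is written to disk. The tiny archive built here (`sample.tgz`, `sample.story`) is a hypothetical stand-in so the example runs on its own; in this notebook the archive would be data/cnn_stories.tgz.

```python
import os
import tarfile
import tempfile

# Build a tiny stand-in archive so the sketch runs anywhere.
# File names here are hypothetical, not part of the real dataset.
workdir = tempfile.mkdtemp()
sample_path = os.path.join(workdir, "sample.story")
with open(sample_path, "w", encoding="utf-8") as f:
    f.write("Body text.\n\n@highlight\n\nA highlight.")

archive_path = os.path.join(workdir, "sample.tgz")
with tarfile.open(archive_path, "w:gz") as tar:
    tar.add(sample_path, arcname="cnn/stories/sample.story")

# Open the archive in a context manager so the handle is always closed.
target_dir = os.path.join(workdir, "extracted")
with tarfile.open(archive_path, "r:gz") as tar:
    member_names = tar.getnames()  # inspect the contents before extracting
    tar.extractall(target_dir)

extracted = os.listdir(os.path.join(target_dir, "cnn", "stories"))
print(member_names)
print(extracted)
```

Listing the members first is a cheap way to confirm the directory layout (here cnn/stories/) before pointing later cells at it.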

In [8]:
story_filepath = 'data/cnn_stories/cnn/stories'
story_list = os.listdir(story_filepath)

Let's take a look at a few of our stories to see what the data looks like.

In [9]:
for file in story_list[0:5]:
    # Use a context manager so each file is closed after it is read
    with open(os.path.join(story_filepath, file), encoding='utf-8') as f:
        text = f.read()
    print(text)
    print('\n\n')

ASUNCION, Paraguay (CNN)  -- Paraguayan President Fernando Lugo acknowledged Monday that he is the father of a 2-year-old child who was conceived when he was a Roman Catholic bishop.

Paraguayan President Fernando Lugo speaks at a news conference in Asuncion on Monday.

"It's true that there was a relationship with Viviana Carrillo," Lugo told reporters, citing the mother. "I assume all the responsibilities that could derive from such an act, recognizing the paternity of the child."

He said he was making the acknowledgment "with the most absolute honesty, transparency and feeling of obligation."

The announcement came in the week after Carrillo had filed suit in a city in southern Paraguay seeking a paternity test.

Judge Evelyn Peralta, who is overseeing the case, said she was treating it routinely. "It is a case like any other, which involves the president and nothing more," she said. "It will be processed as it should be."

Some Cabinet members interpreted Lugo's acknowledgment of 

Next, we need a few functions to read in every document and clean it.

Write a function that loads each story into an object so that we can clean it later, or use the one below (load_stories). Each story also ends with one or more "@highlight" markers, each followed by a short summary, or highlight, of the story.

Write a function to pull those highlights out to use as labels for our data, or use the one below (split_story).

Finally, we need to remove all non-ASCII characters and replace the line breaks in our text with spaces. Write a function to perform these actions, or use the one below (clean_text).

In [10]:
def split_story(text):
    """
    Splits a given text into a story and its highlights.

    Args:
        text (str): The text to split.

    Returns:
        tuple: A (story, highlights) pair, where story is a string and
            highlights is a list of strings.
        
    Example:
        If we have a text string containing a story and its highlights:
        "This is the story. @highlight This is a highlight. @highlight This is another highlight."
        Calling split_story(text) will return:
        ("This is the story.", ["This is a highlight.", "This is another highlight."])
    """
    idx = text.find('@highlight')
    # find() returns -1 when the marker is missing; treat the whole text as the story
    if idx == -1:
        return text.strip(), []
    story, highlights = text[:idx], text[idx:].split('@highlight')
    # strip whitespace from highlights and drop empty pieces
    highlights = [h.strip() for h in highlights if h.strip()]
    return story.strip(), highlights

def clean_text(text):
    """
    Replaces newlines with spaces, removes the ' -- ' separator and the
    '(CNN)' tag, and drops non-ASCII characters from a given text.

    Args:
        text (str): The text to clean.

    Returns:
        str: The cleaned text.
        
    Example:
        If we have a text string containing newlines and non-ASCII characters:
        "This is some text.\nIt has a € symbol."
        Calling clean_text(text) will return:
        "This is some text. It has a  symbol."
    """
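    # Turn line breaks into spaces so each story is a single line
    text = text.replace('\n', ' ')
    # Remove the dateline separator and the network tag
    text = text.replace(' -- ', '')
    text = text.replace('(CNN)', '')
    # Keep only ASCII characters (code points below 128)
    return ''.join(t for t in text if ord(t) < 128)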
    
def load_stories(directory):
    """
    Loads stories from a directory and returns them in a list.

    Args:
        directory (str): The path to the directory containing the .story files of the stories.

    Returns:
        list: A list of dictionaries, where each dictionary contains a story and its highlights.

    Raises:
        FileNotFoundError: If the directory specified does not exist.
        NotADirectoryError: If the path specified is a file and not a directory.
        UnicodeDecodeError: If a file contains bytes that cannot be decoded as UTF-8.
        
    Example:
        If we have a directory "stories" containing the following files:
        - story1.story with contents "This is story 1. @highlight This is a highlight."
        - story2.story with contents "This is story 2. @highlight This is another highlight."
        Calling load_stories("stories") will return:
        [
            {'story': 'This is story 1.', 'highlight': ['This is a highlight.']},
            {'story': 'This is story 2.', 'highlight': ['This is another highlight.']}
        ]
    """
    all_stories = []
    for file in os.listdir(directory):
        # Only .story files are loaded; anything else in the directory is skipped
        if file.endswith('.story'):
            # Use a context manager so each file is closed after it is read
            with open(os.path.join(directory, file), encoding='utf-8') as f:
                doc = f.read()
            doc = clean_text(doc)
            story, highlights = split_story(doc)
            #Add the story and highlight
            all_stories.append({'story': story, 'highlight': highlights})
    return all_stories
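As an alternative sketch of the splitting and whitespace steps, `str.partition` returns the whole text unchanged when the marker is absent, so no index check is needed, and `' '.join(text.split())` collapses every run of spaces, tabs, or line breaks into a single space. The names `split_on_marker` and `collapse_whitespace` are hypothetical helpers for illustration, not part of the notebook's own functions.

```python
MARKER = "@highlight"


def split_on_marker(text, marker=MARKER):
    """Hypothetical variant of split_story built on str.partition."""
    # partition returns (before, marker, after); if the marker is absent,
    # 'before' is the whole text and the other two parts are empty strings.
    story, _, rest = text.partition(marker)
    # Split the remainder on further markers and drop empty pieces.
    highlights = [h.strip() for h in rest.split(marker) if h.strip()]
    return story.strip(), highlights


def collapse_whitespace(text):
    """Replace every run of spaces, tabs, or newlines with one space."""
    return " ".join(text.split())


sample = (
    "Body  of the\nstory.\n\n@highlight\n\nFirst point"
    "\n\n@highlight\n\nSecond point"
)
story, highlights = split_on_marker(sample)
story = collapse_whitespace(story)
no_marker = split_on_marker("A story with no highlights.")

print(story)
print(highlights)
print(no_marker)
```

Collapsing whitespace after replacing line breaks avoids the double spaces that a plain newline-to-space replacement leaves behind.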

Let's load every story into our data object.

In [11]:
data = load_stories(story_filepath)

Let's create a vizzy object to get some more visualizations of our text.

In [12]:
df = pd.DataFrame(data)
viz = vizzy.vizzy_sentence(df, 'story')

In [13]:
viz.print_word_count()

The average number of words in your text cells is 647.3588718823924
The max number of words in your text cells is 1876
The smallest number of words in your text cells is 0
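To cross-check figures like these without the helper library, a rough sketch in plain pandas is shown here. It assumes a "word" is a whitespace-separated token, which may not match how vizzy counts words, and it uses a toy DataFrame in place of the real one.

```python
import pandas as pd

# Toy stand-in for the notebook's DataFrame of stories.
toy = pd.DataFrame({"story": ["one two three", "", "four five"]})

# Count whitespace-separated tokens per row; an empty string yields 0.
word_counts = toy["story"].map(lambda s: len(s.split()))

print("mean:", word_counts.mean())
print("max:", word_counts.max())
print("min:", word_counts.min())
```

A minimum of 0 in this kind of summary is the signal that some rows hold no text at all.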


The smallest word count is 0, so some of our stories are empty. Write a function that identifies those rows, or use the loop below. After looking at those rows, let's drop them.

## Visualize Outliers

In [14]:
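for idx, line in enumerate(df.story):
    # len(line) counts characters, so this flags stories that are empty or nearly empty
    if len(line) < 5:
        df = df.drop(idx)
        print(idx, line)
        
# reset_index() keeps the old row labels in a new 'index' column, which is dropped later
df = df.reset_index()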

1333   
1919   
2966   
3426   
4403   
5544   
6987   
7890   
8537   
9492   
9603   
9665   
9939   
11011   
11168   
11812   
12674   
13818   
14081   
14232   
15109   
16671   
17445   
20269   
22086   
22888   
23345   
23683   
24360   
25357   
25823   
27694   
29156   
29179   
30380   
31255   
31335   
32198   
32571   
32794   
34576   
34989   
36809   
37103   
37178   
37334   
41737   
41933   
42140   
42726   
42968   
43674   
44840   
46297   
47891   
48960   
51111   
53949   
55573   
57099   
58065   
59963   
60609   
62570   
63506   
63507   
63946   
65083   
65495   
66140   
66269   
66579   
67844   
69117   
69186   
72112   
72434   
72444   
74601   
75212   
75539   
75834   
76690   
79784   
79843   
80084   
80391   
80685   
81602   
82124   
82230   
82295   
82651   
83453   
84175   
84384   
84499   
84772   
85208   
86020   
86289   
86660   
86874   
87370   
87830   
88036   
89178   
89435   
89643   
91150   
91359   
91892   
92060
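The row-dropping loop above can also be expressed as a boolean mask, which avoids rebuilding the DataFrame once per dropped row. This is a minimal sketch on a toy DataFrame; the threshold of 5 characters mirrors the loop, and the variable names are illustrative.

```python
import pandas as pd

# Toy stand-in: two real stories, one whitespace-only row, one empty row.
toy = pd.DataFrame(
    {"story": ["A full story about something.", "  ", "Another real story.", ""]}
)

# True for rows whose story has fewer than 5 characters.
too_short = toy["story"].str.len() < 5
removed = toy.index[too_short].tolist()
print(removed)

# Keep the remaining rows; drop=True discards the old row labels
# instead of storing them in a new 'index' column.
kept = toy[~too_short].reset_index(drop=True)
print(kept)
```

Passing `drop=True` is a design choice for this sketch only: the notebook keeps the old labels in an index column and removes that column in a later cell.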

Let's look at the word counts again to check that the empty stories are gone.

In [15]:
viz = vizzy.vizzy_sentence(df, 'story')
viz.print_word_count()

The average number of words in your text cells is 648.1569999459255
The max number of words in your text cells is 1876
The smallest number of words in your text cells is 14


In [17]:
viz.topic_model()



The smallest word count is now 14, so all of the empty stories have been removed. Let's drop the old index column, along with the text_without_stopwords column, and save our data as a pickle for use in the next notebook.

In [18]:
df.drop(columns=['index', 'text_without_stopwords'], inplace=True)

## Save to Pickle

In [19]:
df.to_pickle('cnn_stories_clean.pkl')
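The next notebook can read the file back with `pd.read_pickle`. The sketch below round-trips a toy DataFrame through a temporary file (the name `toy_clean.pkl` is hypothetical) to show that list-valued cells, like those in the highlight column, come back as Python lists rather than strings.

```python
import os
import tempfile

import pandas as pd

# Toy stand-in with the same two columns as the cleaned data.
toy = pd.DataFrame(
    {"story": ["Some story text."], "highlight": [["A highlight."]]}
)

# Write to a temporary location; the file name is hypothetical.
pkl_path = os.path.join(tempfile.mkdtemp(), "toy_clean.pkl")
toy.to_pickle(pkl_path)

# Read it back the way a later notebook would.
restored = pd.read_pickle(pkl_path)
print(restored.loc[0, "highlight"])
print(type(restored.loc[0, "highlight"]))
```

Pickle files can execute code when loaded, so only read pickles that you created yourself or that come from a source you trust.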