<a href="https://colab.research.google.com/github/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/student/W1D1_Tutorial1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> &nbsp; <a href="https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/student/W1D1_Tutorial1.ipynb" target="_parent"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open in Kaggle"/></a>

# Tutorial 1: Generalization in AI

**Week 1, Day 1: Generalization**

**By Neuromatch Academy**

__Content creators:__ Samuele Bolotta & Patrick Mineault

__Content reviewers:__ Samuele Bolotta

__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk

<br>


___


# Tutorial Objectives

*Estimated timing of tutorial: [insert estimated duration of the whole tutorial in minutes]*

This tutorial will introduce you to generalization in the context of modern AI systems. We'll look at a particular system trained for handwriting recognition: TrOCR. We'll review what makes that model tick, namely the transformer architecture, and explore what goes into training and finetuning large-scale models. We'll look at how augmentations can bake particular inductive biases into transformers. Finally, we'll have a bonus section on scaling laws.

Our learning objectives for today are:

1. Identify and articulate common objectives pursued by developers of operational AI systems, such as:

- Out-of-distribution (OOD) robustness; latency; Size, Weight, Power, and Cost (SWaP-C)
- Explainability and understanding

2. Explain at least three strategies for enhancing the generalization capabilities of AI systems, including the contemporary trend of training generic large-scale models on extensive datasets, commonly referred to as the "bitter lesson."

3. Gain practical experience with the fundamentals of deep learning and PyTorch.

**Important note**: this tutorial leverages GPU acceleration. Using a GPU runtime in Colab will make the tutorial run roughly 10x faster.

Let's get started!

In [None]:
# @title Tutorial slides
# @markdown These are the slides for the videos in all tutorials today

#from IPython.display import IFrame
#link_id = "<YOUR_LINK_ID_HERE>"

print("If you want to download the slides: 'Link to the slides'")
      # Example: https://osf.io/download/{link_id}/

#IFrame(src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{link_id}/?direct%26mode=render", width=854, height=480)

---
# Setup



In [None]:
# @title Install dependencies

#!pip install numpy Pillow matplotlib torch torchvision transformers gradio sentencepiece protobuf --quiet
# !pip install sentencepiece gradio torchmetrics --quiet

In [None]:
# @title Import dependencies

# Standard Libraries for file and operating system operations, security, and web requests
import os
import functools
import hashlib
import requests
import logging
import io
import re
import time

# Core python data science and image processing libraries
import numpy as np
from PIL import Image as IMG
from PIL import ImageDraw, ImageFont
import matplotlib.pyplot as plt
import tqdm

# Deep Learning and model specific libraries
import torch
import torchmetrics.functional.text as fm
import transformers
from torchvision import transforms
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Utility and interface libraries
import gradio as gr
from IPython.display import IFrame, display, Image
import sentencepiece
import zipfile
import pandas as pd


device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# @title Figure settings
# @markdown

logging.getLogger('matplotlib.font_manager').disabled = True

%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high-definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

In [None]:
# @title Plotting functions

def display_image(image_path):
    """Display an image from a given file path.

    Inputs:
    - image_path (str): The path to the image file.
    """
    # Open the image with PIL (imported as IMG); `Image` here is IPython.display.Image
    image = IMG.open(image_path)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Display the image
    plt.imshow(image)
    plt.axis('off')  # Turn off the axis
    plt.show()

def display_transformed_images(image, transformations):
    """
    Apply a list of transformations to an image and display them.

    Inputs:
    - image (Tensor): The input image as a tensor.
    - transformations (list): A list of torchvision transformations to apply.
    """
    # Convert tensor image to PIL Image for display
    pil_image = transforms.ToPILImage()(image)

    fig, axs = plt.subplots(len(transformations) + 1, 1, figsize=(5, 15))
    axs[0].imshow(pil_image, cmap='gray')
    axs[0].set_title('Original')
    axs[0].axis('off')

    for i, transform in enumerate(transformations):
        # Apply transformation if it's not the placeholder
        if transform != "Custom ElasticTransform Placeholder":
            transformed_image = transform(image)
            # Convert transformed tensor image to PIL Image for display
            display_image = transforms.ToPILImage()(transformed_image)
            axs[i+1].imshow(display_image, cmap='gray')
            axs[i+1].set_title(transform.__class__.__name__)
            axs[i+1].axis('off')
        else:
            axs[i+1].text(0.5, 0.5, 'ElasticTransform Placeholder', ha='center')
            axs[i+1].axis('off')

    plt.tight_layout()
    plt.show()

def display_original_and_transformed_images(original_tensor, transformed_tensor):
    """
    Display the original and transformed images side by side.

    Inputs:
    - original_tensor (Tensor): The original image as a tensor.
    - transformed_tensor (Tensor): The transformed image as a tensor.
    """
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

    # Display original image
    original_image = original_tensor.permute(1, 2, 0)  # Convert from (C, H, W) to (H, W, C)
    axs[0].imshow(original_image, cmap='gray')
    axs[0].set_title('Original')
    axs[0].axis('off')

    # Display transformed image
    transformed_image = transformed_tensor.permute(1, 2, 0)  # Convert from (C, H, W) to (H, W, C)
    axs[1].imshow(transformed_image, cmap='gray')
    axs[1].set_title('Transformed')
    axs[1].axis('off')

    plt.show()

def display_generated_images(generator):
    """
    Display images generated from strings.

    Inputs:
    - generator (GeneratorFromStrings): A generator that produces images from strings.
    """
    plt.figure(figsize=(15, 3))
    for i, (text_img, lbl) in enumerate(generator, 1):
        ax = plt.subplot(1, len(generator.strings) * generator.count // len(generator.strings), i)
        plt.imshow(text_img)
        plt.title(f"Example {i}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()


# Function to generate an image with text
def generate_image(text, font_path, space_width=2, skewing_angle=8):
    """Generate an image with text.

    Args:
        text (str): Text to be rendered in the image.
        font_path (str): Path to the font file.
        space_width (int): Space width between characters.
        skewing_angle (int): Angle to skew the text image.
    """
    image_size = (350, 50)
    background_color = (255, 255, 255)
    speckle_threshold = 0.05
    speckle_color = (200, 200, 200)
    background = np.random.rand(image_size[1], image_size[0], 1) * 64 + 191
    background = np.tile(background, [1, 1, 4])
    background[:, :, -1] = 255
    image = IMG.fromarray(background.astype('uint8'), 'RGBA')
    image2 = IMG.new('RGBA', image_size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(image2)
    font = ImageFont.truetype(font_path, size=36)
    text_size = draw.textlength(text, font=font)
    text_position = ((image_size[0] - text_size) // 2, (image_size[1] - font.size) // 2)
    draw.text(text_position, text, font=font, fill=(0, 0, 0), spacing=space_width)
    image2 = image2.rotate(skewing_angle)
    image.paste(image2, mask=image2)
    return image

# Function to generate images for multiple strings
def image_generator(strings, font_path, space_width=2, skewing_angle=8):
    """Generate images for multiple strings.

    Args:
        strings (list): List of strings to generate images for.
        font_path (str): Path to the font file.
        space_width (int): Space width between characters.
        skewing_angle (int): Angle to skew the text image.
    """
    for text in strings:
        yield generate_image(text, font_path, space_width, skewing_angle)

In [None]:
def download_file(fname, url, expected_md5):
    """
    Downloads a file from the given URL and saves it locally.
    Verifies the integrity of the file using an MD5 checksum.

    Args:
    - fname (str): The local filename/path to save the downloaded file.
    - url (str): The URL from which to download the file.
    - expected_md5 (str): The expected MD5 checksum to verify the integrity of the downloaded data.
    """
    if not os.path.isfile(fname):
        try:
            r = requests.get(url)
            r.raise_for_status()  # Raises an HTTPError for bad responses
        except (requests.ConnectionError, requests.HTTPError) as e:
            print(f"!!! Failed to download {fname} due to: {str(e)} !!!")
            return
        if hashlib.md5(r.content).hexdigest() == expected_md5:
            with open(fname, "wb") as fid:
                fid.write(r.content)
            print(f"{fname} has been downloaded successfully.")
        else:
            print("!!! Data download appears corrupted !!!")

def extract_zip(zip_fname, folder='.'):
    """
    Extracts a ZIP file to the specified folder.

    Args:
    - zip_fname (str): The filename/path of the ZIP file to be extracted.
    - folder (str): Destination folder where the ZIP contents will be extracted.
    """
    if zipfile.is_zipfile(zip_fname):
        with zipfile.ZipFile(zip_fname, 'r') as zip_ref:
            zip_ref.extractall(folder)
            print(f"Extracted {zip_fname} to {folder}.")
    else:
        print(f"Skipped extraction for {zip_fname} as it is not a zip file.")

# Define the list of files to download, including both ZIP files and other file types
file_info = [
    ("Dancing_Script.zip", "https://osf.io/32yed/download", "d59bd3201b58a37d0d3b4cd0b0ec7400", '.'),
    ("lines.zip", "https://osf.io/8a753/download", "6815ed3987f8eb2fd3bc7678c11f2e9e", 'lines'),
    ("transcripts.csv", "https://osf.io/9hgr8/download", "d81d9ade10db55603cc893345debfaa2", None)  # No extraction needed
]

# Process the downloads and extractions
for fname, url, expected_md5, folder in file_info:
    download_file(fname, url, expected_md5)
    if folder is not None:
        extract_zip(fname, folder)

# Define the list of new images to download
image_info = [
    ("sample_0.png", "https://osf.io/j5ckg/download", "920ae567f707bfee0be29dc854f804ed"),
    ("sample_1.png", "https://osf.io/rfys9/download", "cd28623a829b40d0a1dd8c0f17e9ebd7"),
    ("sample_2.png", "https://osf.io/jsrzv/download", "c189c09abf989eac4e1a8d493bd362d7"),
    ("sample_3.png", "https://osf.io/m87cf/download", "dcffc678266952f18af1fc1242127e98"),
    ("transformer_one_layer.png", "https://osf.io/ak9bm/download", "e5f8407112525c4cd5722ac9700c5948"),
    ("trocr_architecture.png", "https://osf.io/aks8y/download", "6222eb05d3b37c9bf2057271be5bb627"),
    ("W1D1_goal.png", "https://osf.io/sek25/download", "7632703b00f7b0063dccfb519a54e526"),
    ("neuroai_hello_world.png", "https://osf.io/zg4w5/download", "f08b81e47f2fe66b5f25b2ccc204c780")
]

# Download the new images
for fname, url, expected_md5 in image_info:
    download_file(fname, url, expected_md5)

# Section 1: Motivation: building a handwriting recognition app with AI

# Video

In [None]:
# @markdown
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'kmzUvUYf8M4'), ('Bilibili', '')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

Let’s put ourselves into the mindset of an AI developer who wants to build a note app featuring handwriting recognition.

In [None]:
# @title Show image
# @markdown

display(Image(filename="W1D1_goal.png"))

Our intrepid developer goes on HuggingFace and finds a suitable model: [TrOCR](https://huggingface.co/docs/transformers/en/model_doc/trocr)! It's a Transformer-based model that performs Optical Character Recognition and handwriting transcription. Several checkpoints are available, finetuned for different downstream applications like handwriting transcription and printed character recognition. Our developer breathes a deep sigh of relief: they don't have to start from scratch.

In [None]:
# @title Show image
# @markdown

display(Image(filename="trocr_architecture.png"))

In this tutorial, we'll look at the design considerations that go into training and deploying a model like TrOCR, what goes on inside the model's transformers, and how it achieves good–or sometimes not-so-good–out-of-distribution generalization. While the NeuroAI course as a whole will explore new ideas at the frontier of neuroscience and AI, we'll first want to understand one of the bread-and-butter building blocks used in industrial AI: the transformer.

Let's try out this model ourselves!

## Interactive demo 1: TrOCR

We load a pretrained TrOCR checkpoint from HuggingFace.

In [None]:
# Load the pre-trained TrOCR model and processor
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.to(device=device)
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten", use_fast=False)

We write a callback function that calls the preloaded model.

In [None]:
# Define the function to recognize text from an image
def recognize_text(processor, model, image):
    """
    This function takes an image as input and uses a pre-trained language model to generate text from the image.

    Inputs:
    - processor: The processor to use
    - model: The model to use
    - image (PIL Image or Tensor): The input image containing text to be recognized.

    Outputs:
    - text (str): The recognized text extracted from the input image.
    """
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values.to(device))
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return text

We build a simple interface in gradio to try out the model interactively.

In [None]:
# Create a Gradio interface
interface = gr.Interface(
    fn=functools.partial(recognize_text, processor, model),
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="Interactive demo: TrOCR",
    description="Demo for Microsoft’s TrOCR, an encoder-decoder model for OCR on single-text line images.",
)

# Launch the interface
interface.launch()

Go ahead and try some example text to see how it works. You can use images from the internet, or scan your own handwriting. Just make sure that the text fits on one line.

In [None]:
# @title Show image
# @markdown

display(Image(filename="sample_0.png"))
display(Image(filename="sample_1.png"))
display(Image(filename="sample_2.png"))
display(Image(filename="sample_3.png"))

### Discussion point

How effective is the model's performance? Does it exhibit generalization beyond its training vocabulary?

# Section 2: Measuring out-of-distribution generalization in TrOCR

How well does TrOCR work in practice? Our developer needs to know!

Machine learning papers are often filled with tables of benchmarks. The tables in the [TrOCR official paper](https://arxiv.org/abs/2109.10282) report performance on several benchmark datasets, including IAM, [a handwriting database assembled in 1999](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database). The base and large model variants (334M and 558M parameters) display **character error rates (CER) of 3.42% and 2.89%, respectively**.

"Wow!" our developer thinks. "That's probably good enough for my notes app! Guess I can go ahead and deploy it."

## Think! 1

What are some reasons why the character error rate measured on IAM might be too optimistic?

## Coding activity 1: measuring out-of-distribution generalization

Our developer reads through the fine print in the paper and realizes that TrOCR is both *trained* on IAM and *tested* on IAM, albeit on a different set of subjects. To be clear, the train and test splits *are* distinct, but their samples come from the same underlying distribution. Our developer realizes that the reported error rates might be too optimistic:

* IAM was recorded on a tablet. Our developer wants to be able to recognize lines of text handwritten on paper.
* IAM is 25 years old. Maybe people write differently now compared to in the past. Do they even write in cursive anymore?
* The sentences in IAM are based on a widely published corpus. Maybe TrOCR has memorized that corpus.

The more the developer thinks about it, the more they realize that the paper is really estimating *in-distribution generalization*. However, what they care about is how well the model will work when it's deployed *in the wild*, which is closer to **out-of-distribution generalization**.

In this coding activity, you'll measure out-of-distribution generalization on a small subset of the CVL database:

> Kleber, F., Fiel, S., Diem, M., & Sablatnig, R. (2018). CVL Database - An Off-line Database for Writer Retrieval, Writer Identification and Word Spotting [Data set]. Zenodo. https://doi.org/10.5281/zenodo.1492267

Let's first have a look at this new out-of-distribution dataset.

In [None]:
# @title Run this cell to visualize dataset.
def get_images_and_transcripts(df, subject):
    df_ = df[df.subject == subject]
    transcripts = df_.transcript.values.tolist()

    # Load the corresponding images
    images = []
    for _, row in df_.iterrows():
        images.append(IMG.open(row.filename))

    return images, transcripts

def visualize_images_and_transcripts(images, transcripts):
    for img in images:
        display(img)

    for transcript in transcripts:
        print(transcript)

df = pd.read_csv('transcripts.csv')
df['filename'] = df.apply(lambda x: f"lines/{x.subject:04}-{x.line}.jpg", axis=1)
df

This is a small test set with 94 lines sampled from 10 different subjects. Let's have a look at the data from subject 52.

In [None]:
images, true_transcripts = get_images_and_transcripts(df, 52)
visualize_images_and_transcripts(images, true_transcripts)

The text is transcribed from a passage in the novel [Flatland by Edwin Abbott Abbott](https://en.wikipedia.org/wiki/Flatland). How well does the model recognize the text? Run this cell to find out.

In [None]:
def transcribe_images(all_images, model, processor):
    """
    Transcribe a batch of images using an OCR model.

    Args:
        all_images: a list of PIL images.
        model: the model to do image-to-token ids
        processor: the processor which maps token ids to text

    Returns:
        a list of the transcribed text.
    """
    pixel_values = processor(images=all_images, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values.to(device))
    decoded_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded_text

transcribed_text = transcribe_images(images, model, processor)
print(transcribed_text)

### Code exercise 1.1: calculate CER and WER

The model is not perfect, but it performs far better than chance. Let's measure the character and word error rates on this subject's data. Fill in the missing code to measure both error rates on this dataset.
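For intuition about what these two metrics measure, here is a minimal pure-Python sketch based on the Levenshtein edit distance. It is an illustration only, not the torchmetrics implementation; the helper names `edit_distance` and `error_rate` and the example strings are hypothetical.

```python
def edit_distance(hyp, ref):
    """Levenshtein distance between two sequences (strings or lists of words)."""
    prev = list(range(len(ref) + 1))  # distance from an empty hypothesis
    for i, h in enumerate(hyp, start=1):
        curr = [i]  # distance to an empty reference
        for j, r in enumerate(ref, start=1):
            cost = 0 if h == r else 1
            curr.append(min(
                prev[j] + 1,         # delete h from the hypothesis
                curr[j - 1] + 1,     # insert r into the hypothesis
                prev[j - 1] + cost,  # substitute h with r (or match)
            ))
        prev = curr
    return prev[-1]


def error_rate(hyps, refs, unit="char"):
    """Total edit distance divided by total reference length, in chars or words."""
    tokenize = list if unit == "char" else str.split
    errors = sum(edit_distance(tokenize(h), tokenize(r)) for h, r in zip(hyps, refs))
    total = sum(len(tokenize(r)) for r in refs)
    return errors / total


# Hypothetical example: a single missing letter
estimated = ["helo world"]
reference = ["hello world"]

cer = error_rate(estimated, reference, unit="char")
wer = error_rate(estimated, reference, unit="word")
print(f"CER = {cer:.3f}, WER = {wer:.3f}")
```

A single wrong character makes the whole word count as an error, which is why the word error rate is typically higher than the character error rate on the same text.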

In [None]:
def clean_string(input_string):
    """
    Clean string prior to comparison

    Args:
        input_string (str): the input string

    Returns:
        (str) a cleaned string, lowercase, alphabetical characters only, no double spaces
    """

    # Convert all characters to lowercase
    lowercase_string = input_string.lower()

    # Remove non-alphabetic characters
    alpha_string = re.sub(r'[^a-z\s]', '', lowercase_string)

    # Remove double spaces and start and end spaces
    return re.sub(r'\s+', ' ', alpha_string).strip()


def calculate_mismatch(estimated_text, reference_text):
    """
    Calculate mismatch (character and word error rates) between estimated and true text.

    Args:
        estimated_text: a list of strings
        reference_text: a list of strings

    Returns:
        A tuple, (CER and WER)
    """
    # Lowercase the text and remove special characters for the comparison
    estimated_text = [clean_string(x) for x in estimated_text]
    reference_text = [clean_string(x) for x in reference_text]

    ############################################################
    # Fill in this code to calculate character error rate and word error rate.
    # Hint: have a look at the torchmetrics documentation for the proper
    # metrics.
    #
    # https://lightning.ai/docs/torchmetrics/stable/
    raise NotImplementedError("Student has to fill in these lines")
    ############################################################

    # Calculate the character error rate and word error rates. They should be
    # raw floats, not tensors.
    cer = ...
    wer = ...
    return (cer, wer)

In [None]:
# to_remove solution

def clean_string(input_string):
    """
    Clean string prior to comparison

    Args:
        input_string (str): the input string

    Returns:
        (str) a cleaned string, lowercase, alphabetical characters only, no double spaces
    """

    # Convert all characters to lowercase
    lowercase_string = input_string.lower()

    # Remove non-alphabetic characters
    alpha_string = re.sub(r'[^a-z\s]', '', lowercase_string)

    # Remove double spaces and start and end spaces
    return re.sub(r'\s+', ' ', alpha_string).strip()


def calculate_mismatch(estimated_text, reference_text):
    """
    Calculate mismatch (character and word error rates) between estimated and true text.

    Args:
        estimated_text: a list of strings
        reference_text: a list of strings

    Returns:
        A tuple, (CER and WER)
    """
    # Lowercase the text and remove special characters for the comparison
    estimated_text = [clean_string(x) for x in estimated_text]
    reference_text = [clean_string(x) for x in reference_text]

    # Calculate the character error rate and word error rates. They should be
    # raw floats, not tensors.
    cer = fm.char_error_rate(estimated_text, reference_text).item()
    wer = fm.word_error_rate(estimated_text, reference_text).item()
    return (cer, wer)

In [None]:
cer, wer = calculate_mismatch(transcribed_text, true_transcripts)
assert isinstance(cer, float)
cer, wer

For this particular subject, the character error rate is 3.3%, while the word error rate is 10%. Not bad, and in line with the results in the paper.

### Code exercise 1.2: calculate CER and WER across all subjects

Let's measure the same metrics, this time across all subjects.

In [None]:
def calculate_all_mismatch(df, model, processor):
    """
    Calculate CER and WER for all subjects in a dataset

    Args:
        df: a dataframe containing information about images and transcripts
        model: an image-to-text model
        processor: a processor object

    Returns:
        a list of dictionaries containing a per-subject breakdown of the
        results
    """
    subjects = df.subject.unique().tolist()

    results = []

    # Calculate CER and WER for all subjects
    for subject in tqdm.tqdm(subjects):
        ############################################################
        # Fill in the section to calculate the cer and wer for a
        # single subject. Look up at other sections to see how it's
        # done.
        raise NotImplementedError("Student exercise")
        ############################################################

        # Load images and labels for a given subject
        images, true_transcripts = ...

        # Transcribe the images to text
        transcribed_text = ...

        # Calculate the CER and WER
        cer, wer = ...

        results.append({
            'subject': subject,
            'cer': cer,
            'wer': wer,
        })
    return results

In [None]:
# to_remove solution
def calculate_all_mismatch(df, model, processor):
    """
    Calculate CER and WER for all subjects in a dataset

    Args:
        df: a dataframe containing information about images and transcripts
        model: an image-to-text model
        processor: a processor object

    Returns:
        a list of dictionaries containing a per-subject breakdown of the
        results
    """
    subjects = df.subject.unique().tolist()

    results = []

    # Calculate CER and WER for all subjects
    for subject in tqdm.tqdm(subjects):
        # Load images and labels for a given subject
        images, true_transcripts = get_images_and_transcripts(df, subject)

        # Transcribe the images to text
        transcribed_text = transcribe_images(images, model, processor)

        # Calculate the CER and WER
        cer, wer = calculate_mismatch(transcribed_text, true_transcripts)

        results.append({
            'subject': subject,
            'cer': cer,
            'wer': wer,
        })
    return results

In [None]:
results = calculate_all_mismatch(df, model, processor)
df_results = pd.DataFrame(results)
df_results

Not all subjects are as easy to transcribe as subject 52! Let's check out subject 57, who has high CER and WER.

In [None]:
print("A subject that's harder to read")
images, true_transcripts = get_images_and_transcripts(df, 57)
visualize_images_and_transcripts(images, true_transcripts)

Indeed, this text seems harder to read.

### Code exercise 1.3: measure OOD generalization

What we've done thus far is to measure the empirical risk, here the character error rate, for each subject. The empirical risk is defined as:

$$R^e(\theta) = \mathbb{E}^e[ L(y, f(x, \theta)) ] $$

Here:

* The environment $e$ is the distribution the samples are drawn from; in a standard benchmark, this is the training distribution
* $R^e(\theta)$ is the empirical risk in environment $e$
* $\theta$ are the learned parameters of the TrOCR model
* $x$ is the conditioning data, that is, the images
* $f$ is the function approximated by the TrOCR model, which maps images to probabilities of certain tokens
* $L$ is the loss (or metric–not necessarily differentiable) for a single line of text, the character error rate (CER)
* $\mathbb{E}^e$ is the expectation taken over all the samples in environment $e$

A single environment $e$ corresponds to a single subject. The out-of-distribution generalization is instead given by:

$$R^{OOD} = \max_{e \in \mathcal{E}_{all}} R^e(\theta) $$

It's the worst-case empirical risk over the out-of-distribution environments $e \in \mathcal{E}_{all}$ we wish to deploy on. In other words, it is the character error rate for the subject with the most difficult-to-read handwriting.

Intuitively, our AI developer's vision of robustness might be: my note transcription app is robust and generalizes if it works well even when someone has illegible handwriting. The app is only as good as how well it works in the worst-case scenario. Let's measure that.
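As a toy illustration of the two formulas above, the sketch below computes a per-environment risk and then takes the maximum over environments. The writer names and per-line error values are made up for illustration, and averaging per-line errors is a simplification of how the error rate is aggregated elsewhere in this tutorial.

```python
from statistics import mean

# Hypothetical per-line CER values, grouped by environment (here, by writer)
losses_by_env = {
    "writer_a": [0.02, 0.04, 0.03],
    "writer_b": [0.10, 0.14],
    "writer_c": [0.05, 0.07, 0.06, 0.06],
}

# R^e: the empirical risk within each environment
risk_by_env = {env: mean(losses) for env, losses in losses_by_env.items()}

# Average risk across environments, and R^OOD: the worst-case environment
average_risk = mean(risk_by_env.values())
ood_risk = max(risk_by_env.values())
worst_env = max(risk_by_env, key=risk_by_env.get)

print(f"average risk: {average_risk:.3f}")
print(f"OOD (worst-case) risk: {ood_risk:.3f} for {worst_env}")
```

The average can look reassuring while the worst-case value is much larger, which is the gap the OOD risk is meant to expose.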

In [None]:
def calculate_mean_max_cer(df_results):
    """
    Calculate the mean character-error-rate across subjects as
    well as the maximum (that is, the OOD risk).

    Args:
        df_results: a dataframe containing results

    Returns:
        A tuple, (mean_cer, max_cer)
    """
    ############################################################
    # Fill in the section to calculate the mean and max cer
    # across subjects.
    raise NotImplementedError("Student exercise")
    ############################################################

    # Calculate the mean CER across test subjects.
    mean_subjects = ...

    # Calculate the max CER across test subjects.
    max_subjects = ...
    return mean_subjects, max_subjects

In [None]:
# to_remove solution
def calculate_mean_max_cer(df_results):
    """
    Calculate the mean character-error-rate across subjects as
    well as the maximum (that is, the OOD risk).

    Args:
        df_results: a dataframe containing results

    Returns:
        A tuple, (mean_cer, max_cer)
    """
    # Calculate the mean CER across test subjects.
    mean_subjects = df_results.cer.mean()

    # Calculate the max CER across test subjects.
    max_subjects = df_results.cer.max()
    return mean_subjects, max_subjects

In [None]:
mean_subjects, max_subjects = calculate_mean_max_cer(df_results)
mean_subjects, max_subjects

We see that:

* when measured on this (admittedly small) out-of-distribution dataset, the average character error rate is about 5.8%, larger than the 3.4% reported for IAM
* the out-of-distribution character error rate is 12%

Whether that's good enough for our AI developer depends on the use case.

## Discussion

Numbers in tables filled with benchmarks don't tell the whole story: often, we care about OOD robustness. Our developer benchmarked the TrOCR model for their use case and found a worst-case character error rate above 10%. Whether or not that's acceptable is a judgment call, and it's not the only metric the developer might care about. They might also need to meet other constraints:

- Memory, FLOPs, latency, cost of inference: the deployment environment might not be able to support very large-scale models because of memory or compute constraints, or those would run too slowly for the use case. Cloud inference might not be practical with limited internet access.
- SWaP-C: if the model is embodied in a physical device, the Size, Weight, Power and Cost of that device will ultimately be important.
- Latency of development: a bespoke model developed from scratch might take a long time to develop; our busy developer might prefer to adapt a pretrained, sub-optimal architecture rather than build a custom one.
- Cost of upkeep: machine learning systems can be notoriously difficult to keep running. Our developer might prefer to use a suboptimal system managed by somebody else rather than take on the burden of upkeep themselves.
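To get a rough feel for the memory constraint, the following back-of-the-envelope sketch estimates how much memory the weights alone would occupy for the two parameter counts quoted earlier (334M and 558M). It is a simplification: it ignores activations, framework overhead, and any decoding caches, and the function name `weight_memory_gb` is hypothetical.

```python
# Bytes needed to store one parameter at common numeric precisions
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}


def weight_memory_gb(n_params, dtype="float32"):
    """Approximate memory (in GB, 1 GB = 1e9 bytes) needed to hold the weights."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


for name, n_params in [("TrOCR base", 334e6), ("TrOCR large", 558e6)]:
    for dtype in BYTES_PER_PARAM:
        print(f"{name:12s} {dtype:8s} {weight_memory_gb(n_params, dtype):.2f} GB")
```

Even this lower bound suggests that the larger variant may be a tight fit on a memory-constrained device unless a reduced precision is used.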

Our intrepid developer wants to ship this app soon! They decide on a strategy: the model is good enough to get started. They'll deploy the model as is, but they'll have an option in the app to report errors. They'll then label *those* errors and fine-tune the model. Before that though, they want to understand what's inside the model.

# Section 3: Dissecting TrOCR

TrOCR (transformer-based optical character recognition) is a model that performs printed optical character recognition and handwriting transcription on the basis of two transformers. But what's inside of it?

In [None]:
# @title Show image
# @markdown

display(Image(filename="trocr_architecture.png"))

TrOCR uses two transformers in an encoder-decoder architecture:

1. An encoder, a vision transformer (ViT), maps 16x16 patches of the image to individual tokens
2. A decoder, a text transformer, maps previously decoded text and the encoder hidden state to the next token in the sequence to be decoded. This is known as causal language modelling.
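
The decoding loop described in step 2 can be sketched in a few lines. This is a toy illustration, not TrOCR's generation code: `greedy_decode`, `toy_scores`, and the `<s>`/`</s>` markers are hypothetical names, the "decoder" is a stand-in that simply spells out a string, and real systems often use other strategies such as beam search.

```python
from typing import Any, Callable, Dict, List

BOS, EOS = "<s>", "</s>"  # hypothetical start and end markers


def greedy_decode(
    next_token_scores: Callable[[List[str], Any], Dict[str, float]],
    encoder_state: Any,
    max_length: int = 20,
) -> List[str]:
    """Generate tokens one at a time, each conditioned on the tokens so far."""
    tokens = [BOS]
    for _ in range(max_length):
        scores = next_token_scores(tokens, encoder_state)
        best_token = max(scores, key=scores.get)
        if best_token == EOS:
            break
        tokens.append(best_token)
    return tokens[1:]  # drop the start marker


def toy_scores(tokens: List[str], encoder_state: str) -> Dict[str, float]:
    """Stand-in for the decoder: scores the next character of the 'image' text."""
    target = list(encoder_state) + [EOS]
    position = len(tokens) - 1  # number of tokens generated so far
    return {tok: (1.0 if tok == target[position] else 0.0) for tok in set(target)}


print(greedy_decode(toy_scores, encoder_state="hello"))
```

The key point is that each call only sees the previously decoded tokens and the encoder state, never future tokens.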

## Section 3.1: A recap of transformers

[We covered transformers in W2D5 of the DL course](https://deeplearning.neuromatch.io/tutorials/W2D5_AttentionAndTransformers/student/W2D5_Tutorial1.html). Let's quickly recap transformers. Transformers are a class of deep learning architectures that have become dominant in natural language processing (NLP) since their introduction in the paper "Attention is All You Need" by Vaswani et al. in 2017. Their success in natural language processing has led to their application across other domains, including computer vision, which is the case with TrOCR.

In [None]:
# @title Show image
# @markdown

display(Image(filename="transformer_one_layer.png"))



*Illustration from Alammar, J (2018). The Illustrated Transformer. Retrieved from https://jalammar.github.io/illustrated-transformer/*

Transformers are built on self-attention, allowing them to weigh the importance of different parts of the input data differently. This has proven useful for tasks that require an understanding of context, such as language translation, text summarization, and, as we will see, optical character recognition. Some key components of transformers are:

- Tokenization: the input sequence (e.g. sentence) is split into different components (e.g. word pieces). Each component, or token, is embedded into a fixed dimensional space. In natural language processing, tokenization is done via a lookup table: every word piece is mapped to a fixed-dimensional vector. [See W3D1 of the DL course for a refresher on tokenization](https://deeplearning.neuromatch.io/tutorials/W3D1_TimeSeriesAndNaturalLanguageProcessing/student/W3D1_Tutorial2.html?highlight=word2vec#tokenizers).

- Self-attention: A self-attention mechanism allows the tokens in the sequence to interact to form new representations. Specifically, queries and keys are derived from tokens; an inner product between queries and keys, followed by a softmax, forms the attention matrix. The attention matrix is multiplied by the value matrix to obtain a new representation.

- Positional encoding: Positional encoding is added to the input to give the model information about the position of each token within the sequence. Unlike RNNs or CNNs, transformers do not process data in order–without position encoding, they are permutation invariant. We'll dig deeper into what this implies in the section on the inductive biases of transformers.

- Layer Normalization and Residual Connections are used within the transformer architecture to stabilize the learning process and improve the model's ability to learn deep representations.
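
As a refresher, the self-attention step above can be written out in plain Python. This is a dependency-free sketch of a single head, not how TrOCR or PyTorch implement it: the helper names and toy weights are made up, and the division by the square root of the key dimension follows the standard scaled dot-product formulation. The last lines check the claim about position: strictly speaking, without positional encoding the per-token outputs are permutation equivariant, meaning that reordering the input tokens reorders the outputs in the same way.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product of a (n x k) and b (k x m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    largest = max(row)
    exps = [math.exp(x - largest) for x in row]
    total = sum(exps)
    return [x / total for x in exps]


def self_attention(tokens: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix) -> Matrix:
    """Single-head scaled dot-product self-attention."""
    queries, keys, values = (matmul(tokens, w) for w in (w_q, w_k, w_v))
    scale = math.sqrt(len(keys[0]))
    keys_transposed = [list(col) for col in zip(*keys)]
    scores = matmul(queries, keys_transposed)
    attention = [softmax([s / scale for s in row]) for row in scores]
    return matmul(attention, values)


# Three toy tokens with 2-dimensional embeddings and arbitrary projection weights.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_q = [[0.5, -0.2], [0.1, 0.3]]
w_k = [[0.4, 0.1], [-0.3, 0.2]]
w_v = [[1.0, 0.0], [0.0, 1.0]]

output = self_attention(tokens, w_q, w_k, w_v)

# Permute the input tokens and compare against the permuted original output.
permutation = [2, 0, 1]
permuted_output = self_attention([tokens[i] for i in permutation], w_q, w_k, w_v)
expected = [output[i] for i in permutation]
is_equivariant = all(
    math.isclose(a, b, abs_tol=1e-9)
    for row_a, row_b in zip(permuted_output, expected)
    for a, b in zip(row_a, row_b)
)
print(is_equivariant)
```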

One of the key advantages of transformers over previous architectures is a high degree of parallelism, which allows one to train larger, more capable models. Let's now look at the two transformers inside TrOCR.

## Section 3.2: The encoder and decoder

Let's look more closely at the **encoder** inside TrOCR. It's a vision transformer (ViT), an adaptation of transformers for problems in vision. It proceeds as follows:

1. It takes a raw image and resizes it to 384x384
2. It chops it up into 16x16 patches
3. It embeds each patch inside a fixed dimensional space, adding positional embeddings
4. It passes the patches through self-attention layers.
5. It ends up with $577=(384/16)^2+1$ embedded tokens in total: one per patch, plus one extra special token (the `[CLS]` token in standard ViTs). For the base model, each token has an embedding size of 768.
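
The token count in step 5 is easy to check by hand. Here is a small sketch, where `vit_token_count` is a hypothetical helper and the single extra token is assumed to be the special classification token used by standard ViTs:

```python
def vit_token_count(image_size: int, patch_size: int, extra_tokens: int = 1) -> int:
    """Number of tokens a ViT-style encoder produces for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + extra_tokens


print(vit_token_count(image_size=384, patch_size=16))  # 24 * 24 + 1 = 577
```

Halving the patch size would quadruple the number of patch tokens, which matters because the cost of self-attention grows quadratically with sequence length.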

Let's see the structure of the encoder:

In [None]:
model.encoder

### Code exercise 3.1: Understanding the inputs and outputs of the encoder

Let's make sure we understand how the encoder operates by giving it a sample input and checking that its output is the expected shape.

In [None]:
def inspect_decoder(model):
    """
    Inspect the encoder to verify that it processes inputs in the expected way.

    Args:
        model: the TrOCR model
    """
    ##################################################################
    # Feed the encoder an input and measure the output to understand
    # the role of the vision encoder.
    raise NotImplementedError("Student exercise")
    #
    ##################################################################
    # Create an empty tensor (batch size of 1) to feed it to the encoder.
    # Remember that images should have 3 channels and have size 384x384
    # Recall that images are fed in pytorch with tensors of shape
    # batch x channels x height x width
    single_input = ...

    # Run the input through the encoder.
    output = ...

    # Measure the number of hidden tokens which are the output of the encoder
    hidden_shape = output['last_hidden_state'].shape

    assert hidden_shape[0] == 1
    assert hidden_shape[1] == 577
    assert hidden_shape[2] == 768

In [None]:
# to_remove solution

def inspect_decoder(model):
    """
    Inspect the encoder to verify that it processes inputs in the expected way.

    Args:
        model: the TrOCR model
    """
    # Create an empty tensor (batch size of 1) to feed it to the encoder.
    # Remember that images should have 3 channels and have size 384x384
    # Recall that images are fed in pytorch with tensors of shape
    # batch x channels x height x width
    single_input = torch.zeros(1, 3, 384, 384).to(device)

    # Run the input through the encoder.
    output = model.encoder(single_input)

    # Measure the number of hidden tokens which are the output of the encoder
    hidden_shape = output['last_hidden_state'].shape

    assert hidden_shape[0] == 1
    assert hidden_shape[1] == 577
    assert hidden_shape[2] == 768

In [None]:
inspect_decoder(model)

The vision transformer acts much like a conventional encoder transformer in sequence-to-sequence tasks: it maps the input sequence to a hidden representation. This hidden representation is then attended during decoding using cross-attention.

We can locate the cross-attention layers in the decoder: their key and value projections take 768-dimensional inputs (the encoder's hidden size), while their queries are computed from 1024-dimensional states (like the rest of the decoder).

In [None]:
model.decoder
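
To make the dimension bookkeeping explicit, here is a sketch of a single cross-attention head in standard notation. The symbols are introduced here for illustration and are not taken from the model code: $X \in \mathbb{R}^{T \times 1024}$ holds the decoder states for $T$ text tokens, $H \in \mathbb{R}^{577 \times 768}$ is the encoder output, and $d$ is the dimension the projections map into.

$$
\begin{aligned}
Q &= X W_Q, & W_Q &\in \mathbb{R}^{1024 \times d} \\
K &= H W_K, & W_K &\in \mathbb{R}^{768 \times d} \\
V &= H W_V, & W_V &\in \mathbb{R}^{768 \times d} \\
\text{CrossAttention}(X, H) &= \text{softmax}\left(\frac{Q K^\top}{\sqrt{d}}\right) V, & Q K^\top &\in \mathbb{R}^{T \times 577}
\end{aligned}
$$

Each text token therefore gets one attention weight per image token, which is how the decoder decides which parts of the image to read next.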

TL;DR: there's nothing magic going on: there are two relatively large-scale transformers which are wired in the conventional encoder-decoder architecture. The transformers themselves are generic and have relatively weak built-in inductive biases. What allows the model to generalize beyond its training data?

# Section 4: The magic in the data

It's straightforward to write down the encoder-decoder transformer used by TrOCR–it's conceptually quite similar to the original transformer as outlined in Vaswani et al. (2017). What makes the model tick (and potentially break) is a good training pipeline to ensure good OOD generalization. It's worth taking a look at the TrOCR paper to see the many different sources of data which are used to train the model:

1. [The encoder is pretrained on masked image modelling on ImageNet-22k](https://huggingface.co/docs/transformers/en/model_doc/beit)
2. [The decoder is pretrained on masked language modelling on 160GB of raw text](https://arxiv.org/abs/1907.11692)
3. The entire model is trained end-to-end on 648M text lines found in 2M PDF pages on the internet, with the fonts randomly swapped
4. The model is then fine-tuned end-to-end on the IAM handwriting dataset, with heavy augmentation

Let's look at a few of these pieces in turn.

### Coding exercise 4.1: the data in the decoder

In this section, we take a look at how much data is distilled in the decoder. Let's calculate how long it would take to write the same number of words that the model is trained on, as a human.

In [None]:
def calculate_writing_time(total_words, words_per_day, days_per_week, weeks_per_year, average_human_lifespan):
    """
    Calculate the time required to write a given number of words in lifetimes.

    Inputs:
    - total_words (int): total number of words to be written.
    - words_per_day (int): number of words written per day.
    - days_per_week (int): number of days dedicated to writing per week.
    - weeks_per_year (int): number of weeks dedicated to writing per year.
    - average_human_lifespan (int): average lifespan of a human in years.

    Outputs:
    - time_to_write_lifetimes (float): time to write the given words in lifetimes.
    """

    #################################################
    ## TODO for students: fill in the missing variables ##
    # Fill out function and remove
    raise NotImplementedError("Student exercise: fill in the missing variables")
    #################################################

    words_per_year = words_per_day * days_per_week * weeks_per_year

    # Calculate the time to write in years
    time_to_write_years = total_words / ...

    # Calculate the time to write in lifetimes
    time_to_write_lifetimes = time_to_write_years / average_human_lifespan

    return time_to_write_lifetimes

In [None]:
# to_remove solution

def calculate_writing_time(total_words, words_per_day, days_per_week, weeks_per_year, average_human_lifespan):
    """
    Calculate the time required to write a given number of words in lifetimes.

    Inputs:
    - total_words (int): total number of words to be written.
    - words_per_day (int): number of words written per day.
    - days_per_week (int): number of days dedicated to writing per week.
    - weeks_per_year (int): number of weeks dedicated to writing per year.
    - average_human_lifespan (int): average lifespan of a human in years.

    Outputs:
    - time_to_write_lifetimes (float): time to write the given words in lifetimes.
    """

    words_per_year = words_per_day * days_per_week * weeks_per_year

    # Calculate the time to write in years
    time_to_write_years = total_words / words_per_year

    # Calculate the time to write in lifetimes
    time_to_write_lifetimes = time_to_write_years / average_human_lifespan

    return time_to_write_lifetimes

Let's calculate how long it would take to generate all the text that the decoder is pre-trained on, as a human. The RoBERTa paper states that the model is pretrained on 160GB of text.

In [None]:
def bytes_to_words(total_bytes):
    """
    Find the approximate number of words corresponding to a certain number of bytes.

    Args:
        total_bytes (int): the total number of bytes

    Returns:
        The approximate number of words in a total number of bytes
    """
    ##################################################################
    # Use internet skills to find the approximate length of an English
    # word. Using the fact that 1 character = 1 byte for English
    # characters, return the approximate number of words in a given
    # number of bytes.
    raise NotImplementedError("Student exercise")
    ##################################################################
    return ...

In [None]:
# to_remove solution
def bytes_to_words(total_bytes):
    """
    Find the approximate number of words corresponding to a certain number of bytes.

    Args:
        total_bytes (int): the total number of bytes

    Returns:
        The approximate number of words in a total number of bytes
    """
    return total_bytes / 5.8

How many lifetimes would it take to write all this text?

In [None]:
bytes_to_words(160 * 1024 * 1024 * 1024)

In [None]:
# Example values
total_bytes = 160 * 1024 * 1024 * 1024
total_words = bytes_to_words(total_bytes)
words_per_day = 1500
days_per_week = 6
weeks_per_year = 50
# From age 20 to age 80
average_human_lifespan = 60

# Test the function
time_to_write_lifetimes_roberta = calculate_writing_time(
    total_words,
    words_per_day,
    days_per_week,
    weeks_per_year,
    average_human_lifespan
)

# Print the result
print(f"Time to write {total_words:.0f} words in lifetimes: {time_to_write_lifetimes_roberta:.0f} lifetimes")

It's easy to take modern AI for granted, but it's truly remarkable how much data is distilled even in 5-year-old models. 1097 lifetimes is a long time, but RoBERTa is pretty tiny by modern standards. Let's see how much data is distilled in Llama 2, trained on 2 trillion tokens.

In [None]:
# Exploring Llama 2
total_tokens_llama2 = 2e12
total_words_llama2 = total_tokens_llama2 / 1.5  # assuming roughly 1.5 tokens per word
time_to_write_lifetimes_llama = calculate_writing_time(
    total_words_llama2,
    words_per_day,
    days_per_week,
    weeks_per_year,
    average_human_lifespan
)
print(f"Time to write {total_words_llama2:.0f} words in lifetimes: {time_to_write_lifetimes_llama:.0f} lifetimes")
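
For reference, here is the arithmetic behind both estimates written out, using the same assumptions as the cells above: 160 GB taken as $160 \times 1024^3$ bytes, 5.8 bytes per word, 1500 words per day for 6 days a week and 50 weeks a year, a 60-year writing lifetime, and the 1.5 conversion factor between tokens and words.

$$
\begin{aligned}
\text{words per year} &= 1500 \times 6 \times 50 = 450{,}000 \\
\text{RoBERTa words} &\approx \frac{160 \times 1024^3}{5.8} \approx 2.962 \times 10^{10} \\
\text{RoBERTa lifetimes} &\approx \frac{2.962 \times 10^{10}}{450{,}000 \times 60} \approx 1097 \\
\text{Llama 2 words} &\approx \frac{2 \times 10^{12}}{1.5} \approx 1.33 \times 10^{12} \\
\text{Llama 2 lifetimes} &\approx \frac{1.33 \times 10^{12}}{450{,}000 \times 60} \approx 4.9 \times 10^{4}
\end{aligned}
$$

Under these assumptions, the Llama 2 training set corresponds to roughly 45 times more writing than RoBERTa's.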

Based on these considerations, we can see that an astounding amount of data is distilled in the model's decoder. Even so, the decoder is fairly small compared to state-of-the-art language models. The transcription model is therefore both:

1. dependent on its learned language model
2. limited by its learned language model

We see these dynamics at play when we test the model on rare strings like "Neuromatch". RoBERTa was trained in 2019, before Neuromatch existed. It can express out-of-vocabulary words which follow English orthography rules but are not in its training set. Yet, it wouldn't give a high likelihood to the word Neuromatch compared to other reasonable alternatives. A more sophisticated language model might be able to use subtle context to figure out that the word Neuromatch has a high likelihood.

In [None]:
# @title Submit your feedback
#content_review(f"{feedback_prefix}_Calculate_Writing_Time_Exercise")

## Section 4.2: Generalization via augmentation

Another important part of the training recipe for this model is the use of multiple augmentations of the data. When data is not abundant, this can improve generalization. Thus, we take an expressive model with few built-in inductive biases, and through demonstrations, let it learn invariances and equivariances in the data, encouraging generalization.

By applying various transformations to images and displaying the results, you can visually understand how augmentation works and its impact on model performance. Let's look at parts of the TrOCR recipe.

Let's start with loading and visualizing our chosen image.

In [None]:
# Usage
display(Image(filename="neuroai_hello_world.png"))

Now, we will apply a few transformations to this image. You can play around with the input values!

In [None]:
# Convert PIL Image to Tensor
image = IMG.open("neuroai_hello_world.png")
image = transforms.ToTensor()(image)

# Define each transformation separately
# RandomAffine: applies rotations, translations, scaling. Here, rotates by up to ±5 degrees,
# translates by up to 10% of the image size, and scales by a factor between 0.9 and 1.1.
affine = transforms.RandomAffine(degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1))

# ElasticTransform: applies elastic distortions to the image. The 'alpha' parameter controls
# the intensity of the distortion.
elastic = transforms.ElasticTransform(alpha=25.0)

# RandomPerspective: applies random perspective transformations with a specified distortion scale.
perspective = transforms.RandomPerspective(distortion_scale=0.2, p=1.0)

# RandomErasing: randomly erases a rectangle area in the image.
erasing = transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random', inplace=False)

# GaussianBlur: applies gaussian blur with specified kernel size and sigma range.
gaussian_blur = transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.8, 5))

Let's now combine them in a single list and display the images.

In [None]:
# A list of all transformations for iteration
transformations = [affine, elastic, perspective, erasing, gaussian_blur]

# Display
display_transformed_images(image, transformations)

The images displayed above are:

1. Original: the baseline image without any modifications.
2. RandomAffine: applies random affine transformations to the image, which can include translation, scaling, rotation, and shearing (here we use only the first three). This helps the model become invariant to such transformations in the input data.
3. ElasticTransform: introduces random elastic deformations, simulating non-linear transformations that might occur naturally. It is useful for tasks where we expect such distortions, like medical image analysis.
4. RandomPerspective: changes the perspective from which the image is viewed, simulating the effect of viewing the object from different angles.
5. RandomErasing: randomly removes parts of the image and fills them with arbitrary pixel values. It can make the model robust against occlusions in the input data.
6. GaussianBlur: applies a Gaussian blur to the image, smoothing it. This can help the model cope better with out-of-focus images.

All of these augmentations, which are part of this model's training recipe, help prevent overfitting and improve the generalization of the model to new, unseen images. We can compose these to create new challenging training images:

In [None]:
# Combine all the transformations
all_transforms = transforms.Compose([
    affine,
    elastic,
    perspective,
    erasing,
    gaussian_blur
])

# Apply combined transformation
augmented_image_tensor = all_transforms(image)

display_original_and_transformed_images(image, augmented_image_tensor)

Now, all those transformations are applied, one after the other, to the same image.
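
To see what composing augmentations means without any library machinery, here is a dependency-free sketch on a tiny nested-list "image". It is not how torchvision implements its transforms: `random_shift`, `random_erase`, and `compose` are hypothetical helpers written for illustration.

```python
import random
from typing import Callable, List

Image2D = List[List[int]]
Transform = Callable[[Image2D], Image2D]


def random_shift(max_shift: int, rng: random.Random) -> Transform:
    """Shift every row horizontally by one random offset, padding with zeros."""
    def apply(image: Image2D) -> Image2D:
        shift = rng.randint(-max_shift, max_shift)
        width = len(image[0])
        return [
            [row[c - shift] if 0 <= c - shift < width else 0 for c in range(width)]
            for row in image
        ]
    return apply


def random_erase(size: int, rng: random.Random) -> Transform:
    """Zero out a randomly placed size x size square."""
    def apply(image: Image2D) -> Image2D:
        top = rng.randint(0, len(image) - size)
        left = rng.randint(0, len(image[0]) - size)
        return [
            [0 if top <= r < top + size and left <= c < left + size else pixel
             for c, pixel in enumerate(row)]
            for r, row in enumerate(image)
        ]
    return apply


def compose(steps: List[Transform]) -> Transform:
    """Apply each step in order, feeding each output into the next."""
    def apply(image: Image2D) -> Image2D:
        for step in steps:
            image = step(image)
        return image
    return apply


rng = random.Random(0)  # seeded so the sketch is reproducible
toy_image = [[1] * 6 for _ in range(4)]  # a toy 4 x 6 "image" of ones
augment = compose([random_shift(max_shift=2, rng=rng), random_erase(size=2, rng=rng)])
augmented = augment(toy_image)
for row in augmented:
    print(row)
```

Because each step draws fresh random parameters on every call, running `augment` repeatedly on the same image yields a stream of different training examples that all share the same label.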

In [None]:
# @title Submit your feedback
#content_review(f"{feedback_prefix}_Augmentation_Exercise")

## Section 4.3: Generalization via synthetic data

When augmentation is not enough, we can further improve generalization by training on synthetic data. This allows us to stretch our data even further. Data augmentation creates variations of existing data without changing its inherent properties, while synthetic data generation creates entirely new data that mimics the characteristics of real data.

As it turns out, generating new text is tractable–text can be rendered in a wide range of cursive fonts to simulate real data. Here, we'll showcase this idea by defining a few strings and creating a generator that produces a synthetic version of the input data.

In [None]:
# Define the strings to render
strings = ['Hello world', 'This is the first tutorial', 'For Neuromatch NeuroAI']

# Specify the font path
font_path = "DancingScript-VariableFont_wght.ttf"  # Ensure this path is correct

# Create a generator with the specified parameters
generator = image_generator(strings, font_path, space_width=2, skewing_angle=3)

# Display each generated image with a 1-based example number
for i, img in enumerate(generator, start=1):
    plt.imshow(img, cmap='gray')
    plt.title(f"Example {i}")
    plt.axis('off')
    plt.show()

### Discussion point

What does this type of synthetic data capture that wouldn’t be easy to capture through data augmentation?

In [None]:
# @title Submit your feedback
#content_review(f"{feedback_prefix}_Discussion_Synthetic_Exercise")

### Interactive demo 4.1: Generating handwriting style data

We can take this idea further and generate handwriting style data. We will use an embedded calligrapher.ai model to generate new snippets of writing data.

In [None]:
IFrame("https://www.calligrapher.ai/", width=800, height=600)

In [None]:
# @title Submit your feedback
#content_review(f"{feedback_prefix}_Generate_Handwriting_Exercise")

# Conclusion

We train models to minimize a loss function. Oftentimes, however, what we care about is something different, like how well the model will generalize when it's deployed. Our intrepid developer got a rude awakening when comparing the model's OOD performance to its reported benchmark results: the average character error rate was noticeably higher than the figure reported for IAM, and the worst-case error rate was several times higher. Motivated by other factors, like engineering complexity, our developer decided to move forward and deploy a handwriting transcription system, hoping it could be fine-tuned based on user data later.

There's a lot that goes into the training of robust AI models that generalize well. Generic high-capacity models with weak inductive biases, like transformers, are trained on large-scale data. Pretraining, augmentations and synthetic data can all be part of the recipe for learning good inductive biases that might be hard to express mathematically. Because large-scale models can often require significant compute to train, in practice, models that have been trained for other purposes are adapted and re-used, preventing the need to learn from scratch. These models embody what's known as ["the bitter lesson"](http://www.incompleteideas.net/IncIdeas/BitterLesson.html): general methods that leverage computation are ultimately the most effective, and by a large margin.