In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F


In [2]:
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len = 256):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

In [3]:
train_data = [
    ("The Milky Way galaxy contains billions of stars, each with its own planetary system.", "Astronomy"),
    ("Black holes are regions in space where the gravitational pull is so strong that nothing, not even light, can escape.", "Astronomy"),
    ("The Hubble Space Telescope has provided some of the most breathtaking images of the universe.", "Astronomy"),
    ("Astronomers use telescopes to observe celestial bodies and gather data about the cosmos.", "Astronomy"),
    ("The study of exoplanets has revealed that many of them may have conditions suitable for life.", "Astronomy"),
    ("Dark matter and dark energy are mysterious components that make up most of the universe, but we know very little about them.", "Astronomy"),
    ("The Andromeda Galaxy is on a collision course with the Milky Way and will eventually merge with it.", "Astronomy"),
    ("Comets are icy celestial bodies that develop bright tails when they approach the Sun.", "Astronomy"),
    ("The Big Bang theory is the prevailing explanation for the origin of the universe.", "Astronomy"),
    ("Neutron stars are the remnants of massive stars that have exploded in supernovae.", "Astronomy"),
    ("The Shawshank Redemption is widely considered one of the greatest films of all time.", "Movies"),
    ("Christopher Nolan is known for his mind-bending and visually stunning movies, such as Inception.", "Movies"),
    ("The Marvel Cinematic Universe has created a shared universe that connects multiple superhero films.", "Movies"),
    ("The use of special effects has revolutionized the way blockbuster movies are made.", "Movies"),
    ("Alfred Hitchcock's Psycho is a classic horror film that has influenced many directors.", "Movies"),
    ("The Academy Awards, or Oscars, celebrate the best in film every year.", "Movies"),
    ("Quentin Tarantino is famous for his unique style and dialogue-rich films like Pulp Fiction.", "Movies"),
    ("Animated movies like Toy Story have become beloved classics for both children and adults.", "Movies"),
    ("Star Wars has a massive fanbase and has had a significant impact on popular culture.", "Movies"),
    ("Documentaries can offer profound insights into real-life events and issues.", "Movies"),
    ("Gardening can be a therapeutic activity that helps reduce stress.", "Random"),
    ("The Eiffel Tower in Paris is one of the most famous landmarks in the world.", "Random"),
    ("Cooking new recipes from different cuisines can be an exciting culinary adventure.", "Random"),
    ("Chess is a game of strategy that has been enjoyed for centuries.", "Random"),
    ("Yoga promotes flexibility, strength, and mental clarity.", "Random"),
    ("Reading books can transport you to different worlds and expand your knowledge.", "Random"),
    ("Hiking in nature can be a great way to stay fit and enjoy the outdoors.", "Random"),
    ("Music has the power to evoke strong emotions and bring people together.", "Random"),
    ("Photography allows us to capture and preserve moments in time.", "Random"),
    ("Traveling to new places can broaden your horizons and introduce you to diverse cultures.", "Random"),

    # Astronomy Sentences
    ("The Hubble Space Telescope has greatly enhanced our understanding of the cosmos with its stunning images of distant galaxies.", "Astronomy"),
    ("Jupiter's moon Europa is believed to have a subsurface ocean that may harbor life.", "Astronomy"),
    ("The speed of light is approximately 299,792 kilometers per second, making it the fastest thing in the universe.", "Astronomy"),
    ("The Drake Equation is used to estimate the number of active, communicative extraterrestrial civilizations in the Milky Way galaxy.", "Astronomy"),
    ("The Fermi Paradox questions why, if there are so many stars with likely habitable planets, we have yet to find any signs of alien life.", "Astronomy"),
    ("The Carina Nebula is a large, complex area of bright and dark nebulae in the constellation Carina.", "Astronomy"),
    ("Gravitational waves are ripples in spacetime caused by some of the most violent and energetic processes in the universe.", "Astronomy"),
    ("The Oort Cloud is a spherical shell of icy objects that are believed to surround the solar system and be the source of many comets.", "Astronomy"),
    ("The Kármán line at an altitude of 100 kilometers above sea level is often used to define the boundary between Earth's atmosphere and outer space.", "Astronomy"),
    ("Astrobiology is the study of the origin, evolution, and distribution of life in the universe.", "Astronomy"),

    # Movies Sentences
    ("Schindler's List is a poignant film that tells the story of Oskar Schindler's efforts to save Jews during the Holocaust.", "Movies"),
    ("Woody Allen's films often explore complex relationships and human psychology.", "Movies"),
    ("The Lord of the Rings trilogy is celebrated for its epic storytelling and expansive world-building.", "Movies"),
    ("Musical films like La La Land combine storytelling with music and dance to create a unique cinematic experience.", "Movies"),
    ("Film noir is a genre characterized by its dark, cynical atmosphere and morally ambiguous characters.", "Movies"),
    ("Blade Runner is a seminal science fiction film that raised questions about the nature of humanity and artificial intelligence.", "Movies"),
    ("Animated feature films often take years to produce due to the complexity of animation.", "Movies"),
    ("Cinematography is the art of capturing visual images for film, often contributing significantly to a movie's mood and storytelling.", "Movies"),
    ("Cult films have dedicated fanbases and often achieve enduring popularity despite a lack of mainstream success.", "Movies"),
    ("Film festivals like Sundance and TIFF provide platforms for independent filmmakers to showcase their work.", "Movies"),

    # Random Sentences
    ("Podcasts have become a popular medium for storytelling, interviews, and educational content.", "Random"),
    ("The Great Wall of China stretches over 13,000 miles and is one of the most famous landmarks in the world.", "Random"),
    ("Crafting and DIY projects can provide a creative outlet and a sense of accomplishment.", "Random"),
    ("Robotics and automation are transforming industries from manufacturing to healthcare.", "Random"),
    ("The history of art includes movements like Impressionism, Surrealism, and Abstract Expressionism, each with distinct characteristics.", "Random"),
    ("Astrology is the belief that the alignment of stars and planets can affect human behavior and destiny.", "Random"),
    ("Cryptocurrency is a digital or virtual currency that uses cryptography for security and operates independently of a central bank.", "Random"),
    ("Bird watching can be a peaceful and educational hobby, allowing individuals to observe different bird species in their natural habitats.", "Random"),
    ("The culinary arts involve not just cooking but also the presentation and appreciation of food.", "Random"),
    ("Urban planning involves designing and regulating the use of space to create functional and sustainable cities.", "Random"),

    ("The James Webb Space Telescope is designed to observe the most distant objects in the universe.", "Astronomy"),
    ("Neptune's Great Dark Spot was a storm similar to Jupiter's Great Red Spot.", "Astronomy"),
    ("Asteroids are small rocky bodies that orbit the Sun, mostly found in the asteroid belt.", "Astronomy"),
    ("The Sun's core is the location of nuclear fusion, which powers the Sun.", "Astronomy"),
    ("The summer and winter solstices mark the longest and shortest days of the year.", "Astronomy"),
    ("Binary stars are systems in which two stars orbit their common center of mass.", "Astronomy"),
    ("The Small Magellanic Cloud is a dwarf galaxy near the Milky Way.", "Astronomy"),
    ("Solar flares are sudden eruptions of intense high-energy radiation from the Sun's surface.", "Astronomy"),
    ("The H-R diagram is a scatter plot of stars showing the relationship between their brightness and temperature.", "Astronomy"),
    ("The Kuiper Belt was named after Dutch-American astronomer Gerard Kuiper.", "Astronomy"),

    ("Digital photography has made it easy to capture and share moments instantly.", "Random"),
    ("Chess grandmasters often spend years studying and practicing the game.", "Random"),
    ("The health benefits of regular exercise include improved cardiovascular health and mental well-being.", "Random"),
    ("Baking can be a relaxing and rewarding hobby that also produces delicious results.", "Random"),
    ("Many people enjoy reading mystery novels for the suspense and intrigue.", "Random"),
    ("Genealogy research helps individuals trace their family history and heritage.", "Random"),
    ("Camping allows people to disconnect from technology and reconnect with nature.", "Random"),
    ("Podcasts have surged in popularity as a way to learn and be entertained on the go.", "Random"),
    ("Volunteering is a meaningful way to give back to your community and make a difference.", "Random"),
    ("Journaling can be a therapeutic practice that helps you reflect on your thoughts and experiences.", "Random"),

    ("Tim Burton is known for his gothic and eccentric film style.", "Movies"),
    ("'The Dark Knight' features one of the most iconic performances by Heath Ledger as the Joker.", "Movies"),
    ("'Toy Story' is notable for being the first entirely computer-animated feature film.", "Movies"),
    ("'The Silence of the Lambs' is a thriller that features the well-known character Hannibal Lecter.", "Movies"),
    ("'Jurassic Park' amazed audiences with its realistic depiction of dinosaurs.", "Movies"),
    ("Django Unchained' is a western film directed by Quentin Tarantino.", "Movies"),
    ("'The Big Lebowski' is a cult classic known for its quirky characters and unique plot.", "Movies"),
    ("'Gladiator' is an epic historical drama directed by Ridley Scott.", "Movies"),
    ("'Goodfellas' is a crime film based on the true story of a mobster.", "Movies"),
    ("Pixar's 'Inside Out' creatively explores the emotions of a young girl.", "Movies"),
]

In [4]:
val_data = [
    # Astronomy Sentences
    ("The Orion Nebula is one of the brightest nebulae and is visible to the naked eye in the night sky.", "Astronomy"),
    ("A light-year is a unit of astronomical distance equivalent to the distance that light travels in one year.", "Astronomy"),
    ("The Sun is approximately 4.6 billion years old and is about halfway through its life cycle.", "Astronomy"),
    ("The rings of Saturn are composed mostly of ice particles with a smaller amount of rocky debris and dust.", "Astronomy"),
    ("The Large Magellanic Cloud is a satellite galaxy of the Milky Way and contains the Tarantula Nebula.", "Astronomy"),
    ("Exoplanets are planets that orbit a star outside the solar system.", "Astronomy"),
    ("The study of variable stars helps astronomers understand stellar evolution and the properties of distant galaxies.", "Astronomy"),
    ("The Event Horizon Telescope captured the first-ever image of a black hole in the galaxy M87.", "Astronomy"),
    ("Planetary nebulae are the remnants of certain types of stars that have shed their outer layers.", "Astronomy"),
    ("The phenomenon of redshift occurs when the light from distant galaxies is stretched to longer wavelengths.", "Astronomy"),

    # Movies Sentences
    ("The cinematic techniques pioneered by Alfred Hitchcock are still studied in film schools today.", "Movies"),
    ("The Matrix introduced audiences to groundbreaking special effects and philosophical themes.", "Movies"),
    ("Independent films often tackle unique and challenging subjects that mainstream movies avoid.", "Movies"),
    ("Citizen Kane is frequently cited as one of the greatest films ever made, known for its innovative storytelling.", "Movies"),
    ("The role of women in cinema has evolved significantly over the past century.", "Movies"),
    ("Documentaries like March of the Penguins offer viewers a glimpse into the lives of animals in their natural habitats.", "Movies"),
    ("The script of Casablanca is renowned for its memorable lines and complex characters.", "Movies"),
    ("Avatar broke box office records with its stunning visual effects and immersive world-building.", "Movies"),
    ("Silent films from the early 20th century laid the foundation for modern filmmaking techniques.", "Movies"),
    ("The horror genre often reflects the societal fears and anxieties of the time in which the films are made.", "Movies"),

    # Random Sentences
    ("The art of calligraphy is a beautiful and meditative form of handwriting.", "Random"),
    ("The history of ancient civilizations, such as Egypt and Mesopotamia, offers insights into the development of human society.", "Random"),
    ("Traveling by train can be a scenic and relaxing way to see the countryside.", "Random"),
    ("Baking bread from scratch is a rewarding culinary experience that fills your home with delightful aromas.", "Random"),
    ("The sport of rock climbing requires both physical strength and mental focus.", "Random"),
    ("Learning a new language can broaden your cultural understanding and open up new opportunities.", "Random"),
    ("The study of climate change is critical for understanding its impact on ecosystems and human societies.", "Random"),
    ("Meditation and mindfulness practices have been shown to reduce stress and improve overall well-being.", "Random"),
    ("Gardening can teach patience and provide a sense of accomplishment as plants grow and thrive.", "Random"),
    ("The invention of the printing press by Johannes Gutenberg revolutionized the way information was disseminated.", "Random"),

    # Astronomy Sentences
    ("The Cassini spacecraft provided detailed images of Saturn and its rings.", "Astronomy"),
    ("The moons of Uranus have unique, inclined orbits that intrigue astronomers.", "Astronomy"),
    ("The Crab Nebula is the remnant of a supernova observed in 1054 AD.", "Astronomy"),
    ("Pulsars are highly magnetized, rotating neutron stars that emit beams of electromagnetic radiation.", "Astronomy"),
    ("Messier 31, also known as the Andromeda Galaxy, is the nearest spiral galaxy to the Milky Way.", "Astronomy"),
    ("Galileo Galilei improved the telescope and made many essential astronomical observations.", "Astronomy"),
    ("A solar eclipse occurs when the Moon passes between the Earth and the Sun, blocking the Sun partially or entirely.", "Astronomy"),
    ("The Voyager missions have traveled beyond the solar system, sending data from interstellar space.", "Astronomy"),
    ("Cosmic microwave background radiation is the afterglow of the Big Bang, detected in all directions in space.", "Astronomy"),
    ("The Kuiper Belt contains many dwarf planets, including Pluto.", "Astronomy"),

    # Movies Sentences
    ("Al Pacino's portrayal of Michael Corleone in 'The Godfather' is considered iconic.", "Movies"),
    ("'Forrest Gump' follows the life of a man with a low IQ who experiences extraordinary events.", "Movies"),
    ("Stanley Kubrick directed classic films like '2001: A Space Odyssey' and 'The Shining'.", "Movies"),
    ("'The Avengers' brought together numerous Marvel superheroes in one film.", "Movies"),
    ("Studio Ghibli is renowned for its beautifully animated films like 'Spirited Away'.", "Movies"),
    ("Film critics often analyze movies for their thematic depth and technical craftsmanship.", "Movies"),
    ("'Saw' is a horror film series known for its elaborate traps and psychological thrills.", "Movies"),
    ("Biographical films, or biopics, depict the life of a real person with varying degrees of accuracy.", "Movies"),
    ("Costume design plays a significant role in establishing the setting and characters in a film.", "Movies"),
    ("'The Social Network' chronicled the founding of Facebook and the controversies around it.", "Movies"),

    # Random Sentences
    ("The Olympics bring together athletes from around the world to compete in various sports.", "Random"),
    ("Comic books have created an entire subculture with dedicated enthusiasts and conventions.", "Random"),
    ("Fitness apps can help you track workouts, monitor progress, and stay motivated.", "Random"),
    ("Playing a musical instrument, like the piano or guitar, can be a fulfilling hobby.", "Random"),
    ("E-commerce has revolutionized the way people shop, making it possible to buy almost anything online.", "Random"),
    ("The architecture of ancient Rome includes famous structures like the Colosseum and the Pantheon.", "Random"),
    ("Mindfulness meditation encourages staying present and aware without judgment.", "Random"),
    ("Sustainable living practices include reducing waste, recycling, and using renewable energy sources.", "Random"),
    ("Gardening can improve mental health, reduce stress, and provide fresh produce.", "Random"),
    ("Adventure travel allows people to explore new environments and engage in thrilling activities.", "Random")
]

In [5]:
test_data = [
    # Astronomy Sentences
    ("The James Webb Space Telescope is expected to be the successor to the Hubble Space Telescope.", "Astronomy"),
    ("Supernovae are powerful and luminous stellar explosions that occur at the end of a star's lifecycle.", "Astronomy"),
    ("The observable universe is thought to be about 93 billion light-years in diameter.", "Astronomy"),
    ("Quasars are extremely luminous active galactic nuclei powered by supermassive black holes.", "Astronomy"),
    ("The Kuiper Belt is a region of the solar system beyond Neptune that contains many small icy bodies.", "Astronomy"),
    ("The term 'light-year' refers to the distance that light travels in one year, approximately 5.88 trillion miles.", "Astronomy"),
    ("The surface of Mars is covered with iron oxide, giving it its distinctive red color.", "Astronomy"),
    ("The first human-made object to reach space was the Soviet Union's Sputnik satellite in 1957.", "Astronomy"),
    ("Venus is often called Earth's 'sister planet' because of their similar size and composition.", "Astronomy"),
    ("The Great Red Spot on Jupiter is a giant storm that has been raging for at least 400 years.", "Astronomy"),

    # Movies Sentences
    ("The Godfather is often regarded as one of the greatest films in cinema history.", "Movies"),
    ("Alfred Hitchcock was known as the 'Master of Suspense' for his thrilling movies.", "Movies"),
    ("Directors like Steven Spielberg have revolutionized the film industry with blockbuster hits like 'Jaws'.", "Movies"),
    ("'Parasite' made history by becoming the first non-English language film to win the Oscar for Best Picture.", "Movies"),
    ("Bollywood, the Hindi-language film industry based in Mumbai, India, produces hundreds of films each year.", "Movies"),
    ("Pixar Animation Studios is known for creating critically acclaimed animated films like 'Finding Nemo'.", "Movies"),
    ("The Berlin International Film Festival is one of the world's leading film festivals.", "Movies"),
    ("'Inglourious Basterds' is a war film written and directed by Quentin Tarantino.", "Movies"),
    ("Special effects in movies have advanced significantly with the advent of computer-generated imagery (CGI).", "Movies"),
    ("The Cannes Film Festival is held annually in Cannes, France, and is one of the most prestigious film festivals in the world.", "Movies"),

    # Random Sentences
    ("A well-balanced diet includes a variety of fruits, vegetables, proteins, and whole grains.", "Random"),
    ("The internet has transformed the way we communicate, access information, and entertain ourselves.", "Random"),
    ("Painting can be a relaxing and creative outlet for self-expression.", "Random"),
    ("Mountain biking is an adrenaline-pumping activity that combines fitness and a love for the outdoors.", "Random"),
    ("Board games, like chess and Monopoly, provide an excellent way for families and friends to bond.", "Random"),
    ("Classical music has the power to evoke strong emotions and has been enjoyed for centuries.", "Random"),
    ("Artificial intelligence is rapidly changing industries and daily life with advancements in automation and data analysis.", "Random"),
    ("Water conservation is critical for ensuring sustainable resources for future generations.", "Random"),
    ("Volunteering can be a rewarding experience that helps support communities and causes.", "Random"),
    ("Virtual reality technology is creating immersive experiences in gaming, education, and training.", "Random")
]


In [6]:
BERT_MODEL_NAME = 'bert-base-uncased'
NUM_CLASSES = 3  # Specify the number of classes
MAX_LEN = 32
BATCH_SIZE = 10
EPOCHS = 10
TOP_K = 2  # Number of top classes to recommend


In [7]:
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME, force_download=True)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [8]:
label_index = {"Astronomy": 1, "Movies": 2, }

def make_dataset(data):
    return TextDataset(
        texts = [d[0] for d in data],
        labels = [label_index[d[1]] if d[1] in label_index else 0 for d in data],
        # labels = [1 if d[1] == "Astronomy" else 0 for d in data],
        tokenizer = tokenizer,
    )

train_dataset = make_dataset(train_data)
test_dataset = make_dataset(test_data)
val_dataset = make_dataset(val_data)


In [9]:
train_dataset[11]

{'text': 'Christopher Nolan is known for his mind-bending and visually stunning movies, such as Inception.',
 'input_ids': tensor([  101,  5696, 13401,  2003,  2124,  2005,  2010,  2568,  1011, 14457,
          1998, 17453, 14726,  5691,  1010,  2107,  2004, 12149,  1012,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
         

In [10]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [11]:
from transformers import BertForSequenceClassification

classifier = BertForSequenceClassification.from_pretrained(
    BERT_MODEL_NAME,
    num_labels = NUM_CLASSES,
    output_attentions = False,
    output_hidden_states = False
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW

optimizer = AdamW(classifier.parameters(),
                  lr=1e-5,
                  eps=1e-8)

epochs = 5

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=len(train_loader)*epochs)


In [13]:
from sklearn.metrics import f1_score

def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return f1_score(labels_flat, preds_flat, average='weighted')

def accuracy_per_class(preds, labels):
    label_dict_inverse = {v: k for k, v in label_index.items()}
    label_dict_inverse["Random"] = 0

    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()

    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat==label]
        y_true = labels_flat[labels_flat==label]
        print(f'Class: {label_dict_inverse[label]}')
        print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\n')

In [14]:
from tqdm import tqdm
import numpy as np
import random

seed_val = 17
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

device = "cpu"

def evaluate(dataloader_val):

    classifier.eval()

    loss_val_total = 0
    predictions, true_vals = [], []

    for batch in dataloader_val:

        inputs = {'input_ids':      batch["input_ids"],
                  'attention_mask': batch["attention_mask"],
                  'labels':         batch["label"],
                  }

        with torch.no_grad():
            outputs = classifier(**inputs)

        loss = outputs[0]
        logits = outputs[1]
        loss_val_total += loss.item()

        logits = logits.detach().cpu().numpy()
        label_ids = inputs['labels'].cpu().numpy()
        predictions.append(logits)
        true_vals.append(label_ids)

    loss_val_avg = loss_val_total/len(dataloader_val)

    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)

    return loss_val_avg, predictions, true_vals

for epoch in tqdm(range(1, epochs+1)):

    classifier.train()

    loss_train_total = 0

    progress_bar = tqdm(train_loader, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for batch in progress_bar:

        classifier.zero_grad()

        inputs = {'input_ids':      batch["input_ids"],
                  'attention_mask': batch["attention_mask"],
                  'labels':         batch["label"],
                  }

        outputs = classifier(**inputs)

        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(batch))})


    torch.save(classifier.state_dict(), f'../data/finetuned_BERT_epoch_{epoch}.model')

    tqdm.write(f'\nEpoch {epoch}')

    loss_train_avg = loss_train_total/len(train_loader)
    tqdm.write(f'Training loss: {loss_train_avg}')

    val_loss, predictions, true_vals = evaluate(val_loader)
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Validation loss: {val_loss}')
    tqdm.write(f'F1 Score (Weighted): {val_f1}')

  0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 1:   0%|          | 0/9 [00:03<?, ?it/s, training_loss=0.287][A
Epoch 1:  11%|█         | 1/9 [00:03<00:27,  3.50s/it, training_loss=0.287][A
Epoch 1:  11%|█         | 1/9 [00:06<00:27,  3.50s/it, training_loss=0.247][A
Epoch 1:  22%|██▏       | 2/9 [00:06<00:21,  3.08s/it, training_loss=0.247][A
Epoch 1:  22%|██▏       | 2/9 [00:09<00:21,  3.08s/it, training_loss=0.270][A
Epoch 1:  33%|███▎      | 3/9 [00:09<00:20,  3.37s/it, training_loss=0.270][A
Epoch 1:  33%|███▎      | 3/9 [00:12<00:20,  3.37s/it, training_loss=0.309][A
Epoch 1:  44%|████▍     | 4/9 [00:12<00:15,  3.11s/it, training_loss=0.309][A
Epoch 1:  44%|████▍     | 4/9 [00:15<00:15,  3.11s/it, training_loss=0.277][A
Epoch 1:  56%|█████▌    | 5/9 [00:15<00:11,  2.99s/it, training_loss=0.277][A
Epoch 1:  56%|█████▌    | 5/9 [00:18<00:11,  2.99s/it, training_loss=0.327][A
Epoch 1:  67%|██████▋   | 6/9 [00:18<00:08,  2.93s/


Epoch 1
Training loss: 1.1559863355424669


 20%|██        | 1/5 [00:30<02:02, 30.54s/it]

Validation loss: 1.1084368924299877
F1 Score (Weighted): 0.16666666666666666



Epoch 2:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 2:   0%|          | 0/9 [00:03<?, ?it/s, training_loss=0.252][A
Epoch 2:  11%|█         | 1/9 [00:03<00:28,  3.58s/it, training_loss=0.252][A
Epoch 2:  11%|█         | 1/9 [00:06<00:28,  3.58s/it, training_loss=0.267][A
Epoch 2:  22%|██▏       | 2/9 [00:06<00:22,  3.20s/it, training_loss=0.267][A
Epoch 2:  22%|██▏       | 2/9 [00:09<00:22,  3.20s/it, training_loss=0.304][A
Epoch 2:  33%|███▎      | 3/9 [00:09<00:18,  3.05s/it, training_loss=0.304][A
Epoch 2:  33%|███▎      | 3/9 [00:12<00:18,  3.05s/it, training_loss=0.261][A
Epoch 2:  44%|████▍     | 4/9 [00:12<00:15,  3.09s/it, training_loss=0.261][A
Epoch 2:  44%|████▍     | 4/9 [00:15<00:15,  3.09s/it, training_loss=0.278][A
Epoch 2:  56%|█████▌    | 5/9 [00:15<00:12,  3.07s/it, training_loss=0.278][A
Epoch 2:  56%|█████▌    | 5/9 [00:18<00:12,  3.07s/it, training_loss=0.229][A
Epoch 2:  67%|██████▋   | 6/9 [00:18<00:08,  2.96s/it, training_loss=0.229][A
Epoch 2: 


Epoch 2
Training loss: 1.0159194005860224


 40%|████      | 2/5 [01:01<01:32, 30.73s/it]

Validation loss: 0.9905066092809042
F1 Score (Weighted): 0.3682405874186696



Epoch 3:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 3:   0%|          | 0/9 [00:03<?, ?it/s, training_loss=0.219][A
Epoch 3:  11%|█         | 1/9 [00:03<00:28,  3.53s/it, training_loss=0.219][A
Epoch 3:  11%|█         | 1/9 [00:06<00:28,  3.53s/it, training_loss=0.251][A
Epoch 3:  22%|██▏       | 2/9 [00:06<00:21,  3.14s/it, training_loss=0.251][A
Epoch 3:  22%|██▏       | 2/9 [00:09<00:21,  3.14s/it, training_loss=0.212][A
Epoch 3:  33%|███▎      | 3/9 [00:09<00:17,  2.98s/it, training_loss=0.212][A
Epoch 3:  33%|███▎      | 3/9 [00:11<00:17,  2.98s/it, training_loss=0.218][A
Epoch 3:  44%|████▍     | 4/9 [00:11<00:14,  2.87s/it, training_loss=0.218][A
Epoch 3:  44%|████▍     | 4/9 [00:14<00:14,  2.87s/it, training_loss=0.209][A
Epoch 3:  56%|█████▌    | 5/9 [00:14<00:11,  2.81s/it, training_loss=0.209][A
Epoch 3:  56%|█████▌    | 5/9 [00:17<00:11,  2.81s/it, training_loss=0.231][A
Epoch 3:  67%|██████▋   | 6/9 [00:17<00:08,  2.81s/it, training_loss=0.231][A
Epoch 3: 


Epoch 3
Training loss: 0.8827226758003235


 60%|██████    | 3/5 [01:30<01:00, 30.00s/it]

Validation loss: 0.899446020523707
F1 Score (Weighted): 0.571571906354515



Epoch 4:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 4:   0%|          | 0/9 [00:03<?, ?it/s, training_loss=0.219][A
Epoch 4:  11%|█         | 1/9 [00:03<00:29,  3.67s/it, training_loss=0.219][A
Epoch 4:  11%|█         | 1/9 [00:06<00:29,  3.67s/it, training_loss=0.224][A
Epoch 4:  22%|██▏       | 2/9 [00:06<00:22,  3.20s/it, training_loss=0.224][A
Epoch 4:  22%|██▏       | 2/9 [00:09<00:22,  3.20s/it, training_loss=0.205][A
Epoch 4:  33%|███▎      | 3/9 [00:09<00:18,  3.03s/it, training_loss=0.205][A
Epoch 4:  33%|███▎      | 3/9 [00:12<00:18,  3.03s/it, training_loss=0.177][A
Epoch 4:  44%|████▍     | 4/9 [00:12<00:14,  2.93s/it, training_loss=0.177][A
Epoch 4:  44%|████▍     | 4/9 [00:14<00:14,  2.93s/it, training_loss=0.221][A
Epoch 4:  56%|█████▌    | 5/9 [00:14<00:11,  2.89s/it, training_loss=0.221][A
Epoch 4:  56%|█████▌    | 5/9 [00:17<00:11,  2.89s/it, training_loss=0.183][A
Epoch 4:  67%|██████▋   | 6/9 [00:17<00:08,  2.87s/it, training_loss=0.183][A
Epoch 4: 


Epoch 4
Training loss: 0.791619684961107


 80%|████████  | 4/5 [02:00<00:30, 30.01s/it]

Validation loss: 0.8358581165472666
F1 Score (Weighted): 0.7406675450153711



Epoch 5:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 5:   0%|          | 0/9 [00:03<?, ?it/s, training_loss=0.200][A
Epoch 5:  11%|█         | 1/9 [00:03<00:25,  3.20s/it, training_loss=0.200][A
Epoch 5:  11%|█         | 1/9 [00:06<00:25,  3.20s/it, training_loss=0.197][A
Epoch 5:  22%|██▏       | 2/9 [00:06<00:22,  3.15s/it, training_loss=0.197][A
Epoch 5:  22%|██▏       | 2/9 [00:09<00:22,  3.15s/it, training_loss=0.223][A
Epoch 5:  33%|███▎      | 3/9 [00:09<00:18,  3.07s/it, training_loss=0.223][A
Epoch 5:  33%|███▎      | 3/9 [00:12<00:18,  3.07s/it, training_loss=0.185][A
Epoch 5:  44%|████▍     | 4/9 [00:12<00:14,  3.00s/it, training_loss=0.185][A
Epoch 5:  44%|████▍     | 4/9 [00:14<00:14,  3.00s/it, training_loss=0.205][A
Epoch 5:  56%|█████▌    | 5/9 [00:14<00:11,  2.91s/it, training_loss=0.205][A
Epoch 5:  56%|█████▌    | 5/9 [00:17<00:11,  2.91s/it, training_loss=0.174][A
Epoch 5:  67%|██████▋   | 6/9 [00:17<00:08,  2.83s/it, training_loss=0.174][A
Epoch 5: 


Epoch 5
Training loss: 0.7651507125960456


100%|██████████| 5/5 [02:29<00:00, 29.99s/it]

Validation loss: 0.8158403635025024
F1 Score (Weighted): 0.7406675450153711





In [17]:
classifier.eval()
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']

        outputs = classifier(input_ids=input_ids, attention_mask=attention_mask)
        _, predicted = torch.max(outputs.logits.data, 1)
        for t, l, p in zip(batch["text"], labels.tolist(), predicted.tolist()):
            print(f"p={p}, a={l}: {t}")


p=1, a=1: The James Webb Space Telescope is expected to be the successor to the Hubble Space Telescope.
p=1, a=1: Supernovae are powerful and luminous stellar explosions that occur at the end of a star's lifecycle.
p=1, a=1: The observable universe is thought to be about 93 billion light-years in diameter.
p=1, a=1: Quasars are extremely luminous active galactic nuclei powered by supermassive black holes.
p=1, a=1: The Kuiper Belt is a region of the solar system beyond Neptune that contains many small icy bodies.
p=1, a=1: The term 'light-year' refers to the distance that light travels in one year, approximately 5.88 trillion miles.
p=1, a=1: The surface of Mars is covered with iron oxide, giving it its distinctive red color.
p=1, a=1: The first human-made object to reach space was the Soviet Union's Sputnik satellite in 1957.
p=1, a=1: Venus is often called Earth's 'sister planet' because of their similar size and composition.
p=1, a=1: The Great Red Spot on Jupiter is a giant storm t

# DIY FFN

In [10]:
class BertTextClassifier(nn.Module):
    def __init__(self, bert_model_name, num_classes):
        super(BertTextClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.fc = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)  # Adjust number of classes here
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # Output from [CLS] token
        normalized_output = F.normalize(pooled_output, p=2, dim=1)
        output = self.fc(normalized_output)
        return output


In [12]:
model = BertTextClassifier(BERT_MODEL_NAME, NUM_CLASSES)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [13]:
for epoch in range(EPOCHS):
    model.train()
    for batch in train_loader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f}')


Epoch [1/10], Loss: 0.5235
Epoch [2/10], Loss: 0.6156
Epoch [3/10], Loss: 0.6937
Epoch [4/10], Loss: 0.4521
Epoch [5/10], Loss: 0.6787
Epoch [6/10], Loss: 0.7530
Epoch [7/10], Loss: 0.6885
Epoch [8/10], Loss: 0.5017
Epoch [9/10], Loss: 0.6793
Epoch [10/10], Loss: 0.4995


In [14]:
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

# Initialize EarlyStopping
early_stopping = EarlyStopping(patience=3, min_delta=0.0001)

In [15]:
# Training loop with early stopping
# for epoch in range(EPOCHS):
#     model.train()
#     running_loss = 0.0
#     for batch in train_loader:
#         input_ids = batch['input_ids']
#         attention_mask = batch['attention_mask']
#         labels = batch['label']
#         optimizer.zero_grad()
#         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
#         running_loss += loss.item()
#
#     # Validation
#     model.eval()
#     val_loss = 0.0
#     with torch.no_grad():
#         for batch in val_loader:
#             input_ids = batch['input_ids']
#             attention_mask = batch['attention_mask']
#             labels = batch['label']
#             outputs = model(input_ids=input_ids, attention_mask=attention_mask)
#             loss = criterion(outputs, labels)
#             val_loss += loss.item()
#
#     avg_train_loss = running_loss / len(train_loader)
#     avg_val_loss = val_loss / len(val_loader)
#     print(f'Epoch [{epoch+1}/{EPOCHS}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
#
#     # Check early stopping
#     early_stopping(avg_val_loss)
#     if early_stopping.early_stop:
#         print("Early stopping")
#         break


In [16]:
model.eval()
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        outputs = model(input_ids, attention_mask)
        _, predicted = torch.max(outputs.data, 1)
        print(f'Texts: {batch["text"]}')
        print(f'Labels: {labels.tolist()}')
        print(f'Predicted: {predicted.tolist()}')


Texts: ['The James Webb Space Telescope is expected to be the successor to the Hubble Space Telescope.', "Supernovae are powerful and luminous stellar explosions that occur at the end of a star's lifecycle.", 'The observable universe is thought to be about 93 billion light-years in diameter.', 'Quasars are extremely luminous active galactic nuclei powered by supermassive black holes.', 'The Kuiper Belt is a region of the solar system beyond Neptune that contains many small icy bodies.', "The term 'light-year' refers to the distance that light travels in one year, approximately 5.88 trillion miles.", 'The surface of Mars is covered with iron oxide, giving it its distinctive red color.', "The first human-made object to reach space was the Soviet Union's Sputnik satellite in 1957.", "Venus is often called Earth's 'sister planet' because of their similar size and composition.", 'The Great Red Spot on Jupiter is a giant storm that has been raging for at least 400 years.']
Labels: [1, 1, 1, 

In [17]:
# TOP_K = 1
#
# model.eval()
# with torch.no_grad():
#     for batch in test_loader:
#         input_ids = batch['input_ids']
#         attention_mask = batch['attention_mask']
#         labels = batch['label']
#         outputs = model(input_ids, attention_mask)
#
#         # Get the top K predictions
#         top_k_values, top_k_indices = torch.topk(outputs, TOP_K, dim=1)
#
#         print(top_k_values)
#
#         for text, label, top_k_inds in zip(batch["text"], labels, top_k_indices):
#             print(f'Text: {text}')
#             print(f'Actual Label: {label.item()}')
#             print(f'Top {TOP_K} Predicted: {top_k_inds.tolist()}')