In [4]:
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data import get_tokenizer
import torch.nn as nn
from tqdm import tqdm 
from datasets import load_dataset

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

In [6]:
torch.__version__, torch.version.cuda

('2.3.0+cu121', '12.1')

In [7]:
tokenizer = get_tokenizer("basic_english")


dataset = load_dataset("ag_news")
train_dataset = dataset["train"]
test_dataset = dataset["test"]


In [8]:
train_list = list(train_dataset)
test_list = list(test_dataset)
train_list[:5]

[{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
  'label': 2},
 {'text': 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.',
  'label': 2},
 {'text': "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.",
  'label': 2},
 {'text': 'Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Sat

In [9]:

train_x = list(map(lambda x : x["text"], train_list))
train_y = list(map(lambda x : x["label"], train_list))

In [10]:
x_train, x_val, y_train, y_val = train_test_split(train_x , train_y, test_size = 0.1, random_state=42)
# random state , is extremely important here, since the data used for training will determine the vocabulary size, and vocabulary_size will determine num_embeddings, which is part of model architecture.

In [11]:
x_test = list(map(lambda x : x["text"], test_list))
y_test = list(map(lambda x : x["label"], test_list))

In [12]:
vocab = build_vocab_from_iterator([tokenizer(s) for s in x_train],specials=["<unk>"]);
vocab.set_default_index(vocab["<unk>"])

In [13]:
def text_pipeline(s):
    tokenized_s = tokenizer(s)
    return torch.tensor(vocab(tokenized_s)).to(device)

In [14]:
def collate_fn(batch):
        labels, token_indices, offsets = [],[],[0]
        for text, label in batch:
            labels.append(label )
            # tip - its best to test everything on cpu first and then change then change device to "cuda"
            # why?, cause gpu usually hides the error that the pythonr intrepreter throws.
            # labels.append(label - 1) # this line will produce different error, when on gpu and on cpu.
            token_indices.extend( text )
            offsets.append(len(token_indices))
        return torch.tensor(labels).to(device), torch.tensor(token_indices).to(device), torch.tensor(offsets[:-1]).to(device)

In [15]:
x_train, x_val , x_test = [text_pipeline(s) for s in x_train], [text_pipeline(s) for s in x_val],[text_pipeline(s) for s in x_test]
# better to keep away , as much processing as we can, from loaders and the collate_fn

In [16]:
# BATCH_SIZE = 64 # small batches on GPU are inefficient.
BATCH_SIZE = 256
train_loader = DataLoader(list(zip(x_train, y_train)),batch_size = BATCH_SIZE,collate_fn = collate_fn )
test_loader = DataLoader(list(zip(x_test, y_test)), batch_size = BATCH_SIZE, collate_fn=collate_fn)
val_loader = DataLoader(list(zip(x_val, y_val)), batch_size=BATCH_SIZE, collate_fn=collate_fn);

In [17]:
class TextClassifier(nn.Module):
    
    def __init__(self, embedding_dim, num_embeddings):
        super().__init__();
        self.embedding = nn.EmbeddingBag(embedding_dim=embedding_dim, num_embeddings=num_embeddings)
        self.fc1 = nn.Linear(embedding_dim, 4);
    
    def forward(self,x, offsets):
        x = self.embedding(x, offsets)
        return self.fc1(x)


In [18]:
model = TextClassifier(embedding_dim=9, num_embeddings=len(vocab));
model.to(device)# every tensor should be on the same device , that includes the model parameters

TextClassifier(
  (embedding): EmbeddingBag(91226, 9, mode='mean')
  (fc1): Linear(in_features=9, out_features=4, bias=True)
)

In [19]:
def predict(s):
    token_indices = text_pipeline(s)
    labels = model.forward(token_indices, torch.tensor([0]).to(device))
    return labels.argmax()

In [20]:
def evaluate(data_loader):
    model.eval()
    with torch.no_grad():
        cumulative_accuracy = 0
        total = 0.0001
        for labels, token_indices, offsets in data_loader:
            total += 1
            pred_labels = model.forward(token_indices, offsets)
            cumulative_accuracy += (pred_labels.argmax(dim = 1) == labels).to(torch.float).mean()
        return cumulative_accuracy/total

In [None]:
optimizer = torch.optim.SGD(model.parameters(), lr= 0.01)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 50, gamma = 0.7)

In [22]:
model.load_state_dict(torch.load("my_model.pth"))

<All keys matched successfully>

In [23]:
epochs = 1000

model.train()
prev_accuracy = 0
for epoch in tqdm(range(epochs)):

    for labels, token_indices, offsets in train_loader:
        optimizer.zero_grad() # sets all of the gradients to zero.
        pred_labels = model.forward(token_indices, offsets)
        loss = criterion(pred_labels, labels)
        loss.backward() # creates gradients starting from loss.
        nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

    current_accuracy = evaluate(val_loader)
    if(current_accuracy > prev_accuracy):
        torch.save(model.state_dict(), "my_model.pth")
        prev_accuracy = current_accuracy
        print("update model on accuracy : ", prev_accuracy)
    scheduler.step()





  3%|▎         | 1/30 [00:18<08:47, 18.18s/it]

update model on accuracy :  tensor(0.8679)


100%|██████████| 30/30 [09:22<00:00, 18.74s/it]


In [24]:
evaluate(test_loader)

tensor(0.8601)

In [25]:
label_maps = {0:"World", 1:"Sports", 2:"Business", 3:"Sci/Tech"}

In [26]:
l = list(map( lambda x:(x["text"], label_maps[x["label"]]) , test_list[:] )) 
l[:5]

[("Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.",
  'Business'),
 ('The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\\privately funded suborbital space flight, has officially announced the first\\launch date for its manned rocket.',
  'Sci/Tech'),
 ('Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins.',
  'Sci/Tech'),
 ("Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he knows wha

In [27]:
world_truths = [e for e in filter(lambda x: (x[1] == "World"), l)][:3]

In [28]:
sci_tech_truths = [e for e in filter(lambda x: (x[1] == "Sci/Tech"), l)][:3]

In [29]:
business_truths = [e for e in filter(lambda x: (x[1] == "Business"), l)][:3]

In [30]:
sports_truths = [e for e in filter(lambda x: (x[1] == "Sports"), l)][:3]

In [31]:
compare = lambda x : (x[0], x[1], label_maps[predict(x[0]).item()] )

In [32]:
[e for e in map(compare, sports_truths)]

[('Giddy Phelps Touches Gold for First Time Michael Phelps won the gold medal in the 400 individual medley and set a world record in a time of 4 minutes 8.26 seconds.',
  'Sports',
  'Sports'),
 ("Tougher rules won't soften Law's game FOXBOROUGH -- Looking at his ridiculously developed upper body, with huge biceps and hardly an ounce of fat, it's easy to see why Ty Law, arguably the best cornerback in football, chooses physical play over finesse. That's not to imply that he's lacking a finesse component, because he can shut down his side of the field much as Deion Sanders ...",
  'Sports',
  'Sports'),
 ("Shoppach doesn't appear ready to hit the next level With the weeks dwindling until Jason Varitek enters free agency, the Red Sox continue to carefully monitor Kelly Shoppach , their catcher of the future, in his climb toward the majors. The Sox like most of what they have seen at Triple A Pawtucket from Shoppach, though it remains highly uncertain whether he can make the adjustments a

In [33]:
[e for e in map(compare, world_truths)]

[("Sister of man who died in Vancouver police custody slams chief (Canadian Press) Canadian Press - VANCOUVER (CP) - The sister of a man who died after a violent confrontation with police has demanded the city's chief constable resign for defending the officer involved.",
  'World',
  'World'),
 ('Man Sought  #36;50M From McGreevey, Aides Say (AP) AP - The man who claims Gov. James E. McGreevey sexually harassed him was pushing for a cash settlement of up to  #36;50 million before the governor decided to announce that he was gay and had an extramarital affair, sources told The Associated Press.',
  'World',
  'World'),
 ('Explosions Echo Throughout Najaf NAJAF, Iraq - Explosions and gunfire rattled through the city of Najaf as U.S. troops in armored vehicles and tanks rolled back into the streets here Sunday, a day after the collapse of talks - and with them a temporary cease-fire - intended to end the fighting in this holy city...',
  'World',
  'World')]

In [34]:
[e for e in map(compare, business_truths)]

[("Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.",
  'Business',
  'Business'),
 ('Retailers Vie for Back-To-School Buyers (Reuters) Reuters - Apparel retailers are hoping their\\back-to-school fashions will make the grade among\\style-conscious teens and young adults this fall, but it could\\be a tough sell, with students and parents keeping a tighter\\hold on their wallets.',
  'Business',
  'Sports'),
 ("Dollar Briefly Hits 4-Wk Low Vs Euro  LONDON (Reuters) - The dollar dipped to a four-week low  against the euro on Monday before rising slightly on  profit-taking, but steep oil prices and weak U.S. data  continued to fan worries about the health of the world's  largest economy.",
  'Business',
  'Business')]

In [35]:
[e for e in map(compare, sci_tech_truths)]

[('The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\\privately funded suborbital space flight, has officially announced the first\\launch date for its manned rocket.',
  'Sci/Tech',
  'Sci/Tech'),
 ('Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins.',
  'Sci/Tech',
  'Sci/Tech'),
 ("Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he knows what the day will bring. Lightning will strike in places he expects. Winds will pick up, moist places will dry and flames will roar.",
  'Sci/Tech',
  'Sport