<a href="https://colab.research.google.com/github/maryamyazdi/transc/blob/text-classification/classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q transformers

In [10]:
import numpy as np
import pandas as pd
from sklearn import metrics
import nltk
import re
import transformers
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertModel, BertConfig

In [3]:
# Select the compute device: use the GPU when CUDA reports one, otherwise fall back to the CPU.

from torch import cuda

if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [6]:
# Load the raw stories table from the CSV in the working directory.
# The frame is kept under `init_df` so later cells can derive cleaned copies from it.
init_df = pd.read_csv(filepath_or_buffer='stories.csv')

# Bare expression so the notebook renders the frame.
init_df

Unnamed: 0,body,topic
0,,['39822b5f-e37e-43e8-b997-7142fe55c3ea']
1,,['0d817400-3f5d-41e0-929c-c31fdbe75d31']
2,,['83a09c6b-5f2f-421f-ae50-b38acca7e008']
3,,['6fbf954a-03f9-4782-a65f-783271c9c447']
4,hello and welcome to BBC News a woman who gave...,"['83a09c6b-5f2f-421f-ae50-b38acca7e008', '9ff5..."
...,...,...
5176,News. More local help will soon be on the way....,"['9ff54ded-904b-4e0c-85ce-a3617f5cb913', '9632..."
5177,"with March 1, we start what is called Meteorol...",['9a06646a-e1df-4fca-888e-69658420556b']
5178,overseas. A massive Russian convoy is headed t...,['9ff54ded-904b-4e0c-85ce-a3617f5cb913']
5179,"And this morning, the National Hockey League s...","['9ff54ded-904b-4e0c-85ce-a3617f5cb913', 'b492..."


In [12]:
# Pre-processing the domain data (initial dataset)
#
# Produces:
#   df     - cleaned copy of `init_df` (rows without a body removed, index reset,
#            'topic' parsed into a list of ids, 'body' lower-cased with stop words removed)
#   topics - unique topic ids across all rows, in first-seen order
#   stop   - the NLTK English stop-word list

nltk.download('stopwords')
from nltk.corpus import stopwords

# Remove empty body rows.
# fillna('') makes missing (NaN) bodies count as empty; without it NaN is truthy under
# astype(bool), survives the filter and later crashes on `.lower()`.
has_body = init_df['body'].fillna('').str.strip().astype(bool)
# Boolean selection + reset_index builds a new frame, so `init_df` is never modified
# and this cell can be re-run safely.
df = init_df.loc[has_body].reset_index(drop=True)

# Reformat 'topic' column from its string form into a list of topic ids.
# Whole-column assignment replaces the chained `df['topic'][index] = ...`, which raises
# SettingWithCopyWarning and silently does nothing under pandas copy-on-write.
df['topic'] = df['topic'].apply(lambda raw: re.findall("[a-zA-Z0-9-]+", raw))

# Collect the distinct topic ids; dict.fromkeys de-duplicates while keeping first-seen order.
topics = list(dict.fromkeys(topic_id for row_ids in df['topic'] for topic_id in row_ids))

# Removing stop words
stop = stopwords.words('english')
stop_set = set(stop)  # set membership is O(1); the list would be scanned for every word
df['body'] = df['body'].apply(
    lambda x: ' '.join(word for word in x.lower().split() if word not in stop_set)
)

df

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


Unnamed: 0,body,topic
0,hello welcome bbc news woman gave key evidence...,"[83a09c6b-5f2f-421f-ae50-b38acca7e008, 9ff54de..."
1,news north hollywood. 14 yearold girl found de...,[9ff54ded-904b-4e0c-85ce-a3617f5cb913]
2,homelessness city's greatest failure. message ...,"[83a09c6b-5f2f-421f-ae50-b38acca7e008, 74e2fab..."
3,minneapolis police officer kim potter guilty d...,"[83a09c6b-5f2f-421f-ae50-b38acca7e008, 9ff54de..."
4,judy update wildfires wiped entire neighborhoo...,"[9ff54ded-904b-4e0c-85ce-a3617f5cb913, 9a06646..."
...,...,...
5147,news. local help soon way. group volunteers yo...,"[9ff54ded-904b-4e0c-85ce-a3617f5cb913, 9632673..."
5148,"march 1, start called meteorological spring. k...",[9a06646a-e1df-4fca-888e-69658420556b]
5149,overseas. massive russian convoy headed toward...,[9ff54ded-904b-4e0c-85ce-a3617f5cb913]
5150,"morning, national hockey league says suspendin...","[9ff54ded-904b-4e0c-85ce-a3617f5cb913, b49207e..."


In [15]:
# Key settings for the training stage, gathered in one place so they are easy to tune.

# --- Tokenization ---
MAX_LEN = 500  # presumably the max_len handed to the dataset/tokenizer — confirm at the call site
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# --- Batching ---
TRAIN_BATCH_SIZE, VALID_BATCH_SIZE = 8, 4

# --- Optimisation ---
EPOCHS = 1
LEARNING_RATE = 1e-5

In [16]:
class CustomDataset(Dataset):
    """Torch ``Dataset`` that tokenizes each row's text and pairs it with its targets.

    Args:
        dataframe: pandas DataFrame exposing a ``body`` column (text) and a
            ``topic`` column (per-row targets).
        tokenizer: object providing ``encode_plus`` (a ``BertTokenizer`` in this notebook).
        max_len: length every encoded sequence is padded / truncated to.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.body = dataframe.body
        self.targets = self.data.topic
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return len(self.body)

    def __getitem__(self, index):
        """Encode the row at position ``index``.

        Args:
            index: zero-based row position, as produced by a DataLoader sampler.

        Returns:
            dict with 1-D ``torch.long`` tensors ``ids``, ``mask`` and
            ``token_type_ids`` (each of length ``max_len``) and a ``torch.float``
            tensor ``targets`` built from the row's ``topic`` value.
        """
        # Positional lookup: samplers yield 0..len-1, which only coincides with
        # label lookup when the frame has a default RangeIndex.
        body = str(self.body.iloc[index])
        # Collapse runs of whitespace into single spaces.
        body = " ".join(body.split())

        inputs = self.tokenizer.encode_plus(
            body,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',  # replaces the deprecated pad_to_max_length=True
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        def _as_1d_long(value):
            # With return_tensors='pt' the encoder returns (1, max_len) tensors; drop that
            # leading dimension so the DataLoader collates to (batch, max_len) rather than
            # (batch, 1, max_len). as_tensor avoids the copy-construct warning that
            # torch.tensor(tensor) emits.
            return torch.as_tensor(value, dtype=torch.long).squeeze(0)

        # NOTE(review): torch.tensor(..., dtype=torch.float) needs numeric targets (e.g. a
        # multi-hot vector). The pre-processing cell leaves 'topic' as a list of id strings,
        # which would fail here — confirm the labels are encoded before this dataset is built.
        targets = torch.tensor(self.targets.iloc[index], dtype=torch.float)

        return {
            'ids': _as_1d_long(inputs['input_ids']),
            'mask': _as_1d_long(inputs['attention_mask']),
            'token_type_ids': _as_1d_long(inputs["token_type_ids"]),
            'targets': targets
        }