In [141]:
import os
import pandas as pd
import numpy as np
import torch
import pickle
import onnx
import torchtext
from dotenv import load_dotenv
from sqlalchemy import create_engine
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
from torchtext.vocab import build_vocab_from_iterator

In [142]:
def get_data():
  engine_name = f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}"
  engine = create_engine(engine_name)

  query_sql = "SELECT * FROM fb_data"
  df = pd.read_sql_query(query_sql, con=engine)
  return df


In [143]:
class FBData(Dataset):
  def __init__(self,df):

    tokenizer = get_tokenizer('basic_english')
    def yield_tokens(data_iter):
      for text in data_iter:
        yield tokenizer(text)
    vocab = build_vocab_from_iterator(yield_tokens(df['text']), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])
    
    df.drop(columns=['id'])
    
    self.data = df
    self.vocab = vocab
    self.tokenizer = tokenizer
  
  def __len__(self):
    return len(self.data)

  def __getitem__(self,idx):
    text = self.data['text'][idx]
    label = self.data['source'][idx]

    
    tokenized_text = self.vocab(self.tokenizer(text))

    SOURCES = {
      0: ["nytimes","cnn","nbc"],
      1: ["FoxNews","DailyMail","NYPost"],
      2: ["bbcnews","Reuters","APNews"]
    }

    for class_label, sources in SOURCES.items():
        if label in sources:
            label = class_label
            break
    return torch.tensor(tokenized_text),torch.tensor(label)

def collate_batch(batch):
  text_list=[]
  label_list=[]
  for text,label in batch:
    text_list.append(text)
    label_list.append(label)

  padded_texts = pad_sequence(text_list,batch_first=True,padding_value=0)
  return padded_texts, label_list

In [144]:
def main():
  df = get_data()
  dataset = FBData(df)
  dataloader = DataLoader(dataset,batch_size=2,shuffle=True,collate_fn=collate_batch)
main()
  

tensor([[   2, 2224,  661, 4452, 2415,    2,   48,  113,   59,    5,   97,  865,
          198,  286,   25, 3615, 2140,    4, 1456,  403,    8,  963, 4443,    1],
        [   2,  229,   34,   77,  568,  264,   31, 2520,   17,  511,   32,    2,
           34,  374,    7,  204,    1,    0,    0,    0,    0,    0,    0,    0]])
[tensor(0), tensor(1)]
tensor([[  10,  138,   40,    9,   23,   10,   24,   15,  190,   60,   66,  809,
         1396,  423,  100,   40, 1321,   17, 2293,  322,    1,   15,    9,   40,
          218,  310,    7, 2130,   45,   84,   75,  740,    4,  640, 2140,    4,
         1456,   13, 1075, 4803, 1005,  403,    1,  133,   46,   18,   70,    1,
         4203],
        [  24,   15,   22,  539,    2,   35,   34,  221,    3,  456,    5,  285,
          423,   32,   28, 2519,  118,    1,  169,   54,   94,   45,    2,   35,
           52,   58,   18,   70,    1, 4194,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0