In [1]:
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
import torchvision
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader,Dataset, random_split
import matplotlib.pyplot as plt
import random 
import time
from tqdm import tqdm_notebook as tq
import warnings
import pickle as pkl
warnings.filterwarnings("ignore")
import string
import sys
from nltk.corpus import stopwords
plt.ion()

In [2]:
BATCH_SIZE = 32
epochs = 1

In [24]:
device = torch.device("gpu" if torch.cuda.is_available() else "cpu")

In [25]:
class YelpReviewsSentimentAnalysis(nn.Module):
    
    def __init__(self,vocab_size,num_class):
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()
        
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

In [58]:
class YelpDataset(Dataset):
    
    def __init__(self,json_file,threshold=3):
        
        self.raw_data = pd.read_json(json_file,lines=True)
    
        self.raw_data['label'] = self.raw_data.stars.apply(lambda x : 1 if x>=3 else 0)
        
        self.raw_data = self.raw_data[["label","text"]].iloc[:1000,]
        
        self.word2idx = {}
        
        self.idx2word = {}
        
        self.word2freq = {}
        self.word_count = 0
        
        self.maxLen = 0
        self.__init__preprocess()
        
        self.data = self.raw_data.to_numpy()
   
        
       
        
    def __len__(self):
        return self.data.shape[0]
    
    def __getitem__(self,idx):
        
        
        sample = self.data[idx,:]
            
        return sample
    
    def __init__clean(self):
        
        
        def clean(text):
            text = text.lower()
        
            text = [ch for ch in text if ch not in string.punctuation]


            text = "".join(text)

            text = [c for c in text if c == " " or c.isalnum()]

            text = "".join(text)

            stop_words = set(stopwords.words("english"))

            text = text.split(" ")

            text = [word for word in text if word not in stop_words]
        
        
            text = " ".join(text)
            
            return text
          
            
        def build_vocab(text):
            
            
            text = text.split(" ")
            
            text_token = []
            
            for word in text:
                
                if word not in self.word2idx:
                    
                    self.word2idx[word] = self.word_count
                    
                    self.idx2word[self.word_count] = word
                    
                    self.word_count+=1
                    
                    
                text_token.append(self.word2idx[word])
             
            self.maxLen = max(self.maxLen,len(text_token))
            return text_token
        
        self.raw_data['text'] = self.raw_data.text.apply(lambda x : clean(x))
        self.raw_data['text'] = self.raw_data.text.apply(lambda x : build_vocab(x))

In [59]:
def get_batch(batch):
    
    
    pass

In [60]:
def train_func(train):
    
    train_data = DataLoader(train,batch_size=BATCH_SIZE,shuffle=True,collate_fn=get_batch)
    
    #print(train)

    for i,(label,text) in enumerate(train_data):
        print(i,text)
        

In [61]:
if __name__ == "__main__":
    
    
    
    start_time = time.time()
    yelp_dataset = YelpDataset(json_file = "~/data/yelp/review.json")
    
    train_len = int(len(yelp_dataset)*0.8)
    
    valid_len = int(len(yelp_dataset)*0.1)
    
    test_len = len(yelp_dataset) - train_len -valid_len
    
    train,valid,test = random_split(yelp_dataset,[train_len,valid_len,test_len])
    
    print(train[2])
    
    print(len(train))
    
    print("time taken {}".format(time.time()-start_time))
    

[1
 'best chinese resto highly recommended 5 stars let us support business best west valley trust wont gi wrong coming place'
 list([89, 499, 1269, 188, 1270, 291, 598, 359, 9, 1271, 397, 89, 1272, 1273, 1274, 383, 1275, 812, 368, 454])]
80
time taken 194.68570709228516


In [41]:
yelp_dataset[2]

array([1,
       'say office really together organized friendly  dr j phillipp great dentist friendly professional  dental assistants helped procedure amazing jewel bailey helped feel comfortable  dont dental insurance insurance office purchase 80 something year gave 25 dental work plus helped get signed care credit knew nothing visit  highly recommend office nice synergy whole office'],
      dtype=object)

In [42]:
yelp_dataset.raw_data

Unnamed: 0,label,text
0,0,total bill horrible service 8gs crooks actuall...
1,1,adore travis hard rocks new kelly cardenas sal...
2,1,say office really together organized friendly ...
3,1,went lunch steak sandwich delicious caesar sal...
4,0,today second three sessions paid although firs...
...,...,...
95,1,love crust pizza sauce decent cheese okay pepp...
96,0,normally give restaurant least 3 stars long go...
97,1,im familiar scottsdale im guessing restaurant ...
98,1,good morning cocktails waitwhat ohits vegasdin...


In [53]:
vocab = Vocab(yelp_dataset.raw_data.text,min_freq=2)

KeyError: '<unk>'