In [None]:
"""
!pip install -U torch torchtext torchdata
"""

In [71]:
# import statements
import pandas as pd

import torch

import torchtext
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer

In [72]:
def get_conf(base_path=None):
    """Build the configuration dictionary used by the notebook.

    Parameters
    ----------
    base_path : str, optional
        Root directory that every entry under ``conf["data"]`` is relative
        to. When omitted, the original hard-coded project directory is used,
        so existing ``get_conf()`` calls behave exactly as before. A non-empty
        value without a trailing separator gets ``"/"`` appended, because the
        consumers build file names by plain string concatenation
        (``conf["path"] + conf["data"][...]``).

    Returns
    -------
    dict
        ``{"path": <root directory>, "data": {<name>: <relative file path>}}``
        with the relative locations of the train/test CSV files, the saved
        vocabulary and the three saved dataset splits.
    """
    if base_path is None:
        # Backward-compatible default: the location the notebook was written for.
        base_path = "/Users/jaydeepchakraborty/JC/git-projects/model_util/"
    elif base_path and not base_path.endswith(("/", "\\")):
        # Keep `path + relative_file` concatenation valid for callers.
        base_path = base_path + "/"

    data_dir = "DataSets/NLPwithDisasterTweets/"
    conf = {
        "path": base_path,
        "data": {
            "train_data_path": data_dir + "modf_train_data.csv",
            "test_data_path": data_dir + "modf_test_data.csv",
            "vocab_path": data_dir + "disaster_tweets.pt",
            "train_dataset": data_dir + "train_dataset.pt",
            "validation_dataset": data_dir + "validation_dataset.pt",
            "test_dataset": data_dir + "test_dataset.pt",
        },
    }

    return conf

In [73]:
class DisasterTweetsDataSet(torch.utils.data.Dataset):
    """Map-style dataset backed by one of the disaster-tweets CSV files.

    The CSV is read eagerly into a pandas DataFrame stored on ``self.data``;
    each item is one row of that frame.

    Parameters
    ----------
    conf : dict
        Configuration with ``conf['path']`` (root directory, concatenated
        as-is with the relative file name) and ``conf['data']`` holding
        ``'train_data_path'`` / ``'test_data_path'``.
    ind : str, default "train"
        Which file to load: ``"train"`` (has a ``target`` column) or
        ``"test"`` (no ``target`` column).

    Raises
    ------
    ValueError
        If ``ind`` is neither ``"train"`` nor ``"test"``. Previously such a
        value left ``self.data`` unset and only failed later, on first use.
    """

    def __init__(self, conf, ind="train"):
        self.conf = conf

        # Columns shared by both files; 'string' is the pandas nullable
        # string dtype, so missing keyword/location values stay <NA>.
        column_types = {"id": 'int64', "keyword": 'string', "location": 'string', "text": 'string'}
        if ind == "train":
            csv_key = 'train_data_path'
            column_types["target"] = 'int64'  # label column exists only in the train file
        elif ind == "test":
            csv_key = 'test_data_path'
        else:
            raise ValueError(f"ind must be 'train' or 'test', got {ind!r}")

        frame = pd.read_csv(self.conf['path'] + self.conf['data'][csv_key])
        self.data = frame.astype(column_types)

    def __len__(self):
        """Return the number of rows loaded from the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the row at position ``idx`` (positional, via ``iloc``)."""
        return self.data.iloc[idx]

In [74]:
class DisasterTweetsDataSetHelper:
    """Prepare the disaster-tweets data: load, split, persist and build a vocab.

    All file locations are resolved from the ``conf`` passed to the
    constructor as ``conf['path'] + conf['data'][<key>]``. No method relies on
    a module-level ``conf`` variable.

    Attributes set by the methods:
        dataset            -- full labelled dataset read from the train CSV
        train_dataset      -- training split of ``dataset``
        validation_dataset -- validation split of ``dataset``
        test_dataset       -- dataset read from the test CSV
        vocab              -- torchtext vocabulary built by ``gen_vocab``
    """

    # Accepted values of the ``ind`` selector. The split for ``ind`` lives on
    # the attribute ``<ind>_dataset`` and is saved under the configuration key
    # ``conf['data']['<ind>_dataset']``.
    _SPLITS = ("train", "validation", "test")

    # Number of rows printed by ``prnt_sample_data``.
    _SAMPLE_ROWS = 3

    def __init__(self, conf):
        """Store the configuration and create the ``basic_english`` tokenizer."""
        self.tokenizer = get_tokenizer("basic_english")

        self.conf = conf

        self.dataset = None
        self.vocab = None
        self.train_dataset = None
        self.validation_dataset = None
        self.test_dataset = None

    def _data_file(self, key):
        """Return the full path of the file registered as ``conf['data'][key]``."""
        return self.conf['path'] + self.conf['data'][key]

    def load_data(self):
        """Read the train CSV into ``dataset`` and the test CSV into ``test_dataset``."""
        # loading entire data
        self.dataset = DisasterTweetsDataSet(self.conf, ind="train")

        # loading test data
        self.test_dataset = DisasterTweetsDataSet(self.conf, ind="test")

    def split_data(self, seed=None):
        """Randomly split ``dataset`` 80/20 into train and validation subsets.

        Parameters
        ----------
        seed : int, optional
            When given, the split is drawn from a dedicated generator seeded
            with this value, making it reproducible. When omitted, the global
            torch random state is used, as before.
        """
        fractions = [0.8, 0.2]
        if seed is None:
            parts = torch.utils.data.random_split(self.dataset, fractions)
        else:
            rng = torch.Generator().manual_seed(seed)
            parts = torch.utils.data.random_split(self.dataset, fractions, generator=rng)

        self.train_dataset, self.validation_dataset = parts[0], parts[1]

    def save_data(self, ind="train"):
        """Save the split selected by ``ind`` with ``torch.save``.

        Nothing is written when ``ind`` is not one of "train", "validation",
        "test", or when the selected split is falsy (``None`` or empty).
        """
        if ind not in self._SPLITS:
            return

        split_name = ind + "_dataset"
        split = getattr(self, split_name)
        if split:
            torch.save(split, self._data_file(split_name))

    def load_saved_data(self, ind="train"):
        """Load the split selected by ``ind`` from disk and print its length.

        Unknown ``ind`` values are ignored.
        """
        if ind not in self._SPLITS:
            return

        split_name = ind + "_dataset"
        # weights_only=False: the file holds a pickled Python object, not
        # plain tensors, and newer PyTorch releases default to
        # weights_only=True, which refuses such files. Unpickling can execute
        # arbitrary code, so only load files written by this notebook.
        loaded = torch.load(self._data_file(split_name), weights_only=False)
        setattr(self, split_name, loaded)
        print(f"loaded {split_name} length {len(loaded)}")

    def yield_tokens(self, data_iter):
        """Yield one token list per row of ``data_iter``.

        Each row must expose keyword, location and text at positions 1, 2
        and 3. The three fields are converted with ``str``, joined with single
        spaces and passed to ``self.tokenizer``.
        """
        for data_val in data_iter:
            # A pandas Series needs .iloc for positional access (plain integer
            # keys on a label-indexed Series are deprecated); tuples and lists
            # are indexed directly.
            fields = data_val.iloc if hasattr(data_val, "iloc") else data_val
            data_keyword, data_loc, data_text = str(fields[1]), str(fields[2]), str(fields[3])
            # NOTE(review): str() of a missing value produces its textual
            # repr, which is tokenized like ordinary text -- confirm this is
            # intended.
            yield self.tokenizer(data_keyword + " " + data_loc + " " + data_text)

    def _print_vocab_summary(self, header):
        """Print ``header`` followed by the size and a few entries of ``self.vocab``."""
        print(header)
        print("The length of the new vocab is", len(self.vocab))
        new_stoi = self.vocab.get_stoi()
        for token in ("new", "<pad>", "<sos>", "<eos>", "<unk>"):
            print(f"The index of '{token}' is", new_stoi[token])
        new_itos = self.vocab.get_itos()
        print("The token at index 17 is", new_itos[17])

    def gen_vocab(self):
        """Build ``self.vocab`` from ``self.dataset`` and print a summary.

        Tokens seen fewer than 2 times are dropped and the vocabulary is
        capped at 20000 entries. The four special tokens come first, and
        ``<unk>`` is the default index for out-of-vocabulary lookups.

        NOTE(review): the vocabulary is built from the full labelled dataset,
        i.e. it also covers the rows that end up in the validation split.
        """
        self.vocab = build_vocab_from_iterator(
            self.yield_tokens(self.dataset),
            min_freq=2,
            max_tokens=20000,
            specials=['<pad>', '<sos>', '<eos>', '<unk>'],
            special_first=True,
        )
        self.vocab.set_default_index(self.vocab['<unk>'])

        self._print_vocab_summary("Vocab generation:- ")

    def save_vocab(self):
        """Save ``self.vocab`` with ``torch.save`` to ``conf['data']['vocab_path']``."""
        torch.save(self.vocab, self._data_file('vocab_path'))

    def load_vocab(self):
        """Load ``self.vocab`` from ``conf['data']['vocab_path']`` and print a summary."""
        # See load_saved_data: the vocabulary is a pickled object, so it needs
        # weights_only=False and must come from a trusted file.
        self.vocab = torch.load(self._data_file('vocab_path'), weights_only=False)
        self._print_vocab_summary("Vocab loading:- ")

    def prnt_sample_data(self, ind="train"):
        """Print a header and the first three rows of the split selected by ``ind``.

        Unknown ``ind`` values print nothing.
        """
        if ind not in self._SPLITS:
            return

        print(f"sample {ind} data")
        for shown, row in enumerate(getattr(self, ind + "_dataset"), start=1):
            print(row)
            if shown == self._SAMPLE_ROWS:
                break

    def populate_disaster_tweets(self):
        """Run the full preparation pipeline end to end.

        Order: load CSVs, split, save the three splits, reload them from disk,
        build / save / reload the vocabulary, then print sample rows.
        """
        # loading the entire and test data
        self.load_data()

        # splitting the data into train and validation
        self.split_data()

        # saving the train dataset, validation dataset, test dataset
        for ind in self._SPLITS:
            self.save_data(ind=ind)

        # loading the saved dataset
        for ind in self._SPLITS:
            self.load_saved_data(ind=ind)

        # generating vocabulary
        self.gen_vocab()

        # saving vocabulary
        self.save_vocab()

        # loading the saved vocabulary
        self.load_vocab()

        # printing sample data
        for ind in self._SPLITS:
            self.prnt_sample_data(ind=ind)
        
    

In [75]:
if __name__ == "__main__":
    # Notebook entry point: build the configuration, then run the whole
    # data-preparation pipeline (load, split, save, reload, vocab, samples).
    
    # configuration
    # NOTE: this binds `conf` at module (kernel) scope; renaming it could
    # break any code that looks the name up globally instead of receiving it
    # as an argument.
    conf = get_conf()
    
    disaster_tweets_obj = DisasterTweetsDataSetHelper(conf)
    disaster_tweets_obj.populate_disaster_tweets()

loaded train_dataset length 6019
loaded validation_dataset length 1504
loaded test_dataset length 3263
Vocab generation:- 
The length of the new vocab is 1039
The index of 'new' is 7
The index of '<pad>' is 0
The index of '<sos>' is 1
The index of '<eos>' is 2
The index of '<unk>' is 3
The token at index 17 is via
Vocab loading:- 
The length of the new vocab is 1039
The index of 'new' is 7
The index of '<pad>' is 0
The index of '<sos>' is 1
The index of '<eos>' is 2
The index of '<unk>' is 3
The token at index 17 is via
sample train data
id                    4823
keyword         evacuation
location              <NA>
text        war evacuation
target                   1
Name: 3327, dtype: object
id                               10241
keyword                        volcano
location                      northern
text        trying blod volcano htptco
target                               0
Name: 7058, dtype: object
id                                               2110
keyword             