In [0]:
import torch

import time
import csv

Below are some helpers, plus a few lines that nudge you to switch your Colab runtime to GPU (in case you haven't already).

In [0]:
def tlog(msg):
    """Print ``msg`` to stdout, prefixed by ``time.asctime()`` (current local time) and three spaces."""
    stamp = time.asctime()
    print('   '.join((stamp, format(msg))))


# This notebook is meant to run on a GPU; if none is attached, point Colab users at the setting.
if torch.cuda.is_available():
    print('GPU ready to go!')
else:
    print('If you are running this notebook in Colab, go to the Runtime menu and select "Change runtime type" to switch to GPU.')

In [0]:
# Column positions within a parsed row of train.tsv, as used by RottenTomatoesDataset.
I_PHRASE_ID, I_SENTENCE_ID, I_PHRASE, I_LABEL = range(4)

class RottenTomatoesDataset(torch.utils.data.Dataset):
    """Phrase rows from a tab-separated file, split into training and validation sets.

    The file's first line is treated as a header and skipped. Every other row is
    kept when its phrase id and sentence id are positive integers and its label is
    a non-negative integer; those three columns are converted to ``int`` in place
    and the phrase column is left as the string read from the file. Rows that fail
    these checks, or cannot be parsed at all (blank, truncated, non-numeric), are
    counted and reported via ``tlog`` instead of crashing the load.

    Kept rows whose sentence id is below ``train_fraction * max_sentence_id`` form
    ``training_rows``; all other kept rows form ``validation_rows``. Since the split
    depends only on sentence id, all phrases of one sentence land in the same split.
    Row order from the file is preserved within each split.

    ``training()`` / ``validating()`` choose which split ``len()`` and indexing use;
    a new instance starts on the training split.

    Args:
        path: TSV file to read (relative paths resolve against the working directory).
        train_fraction: fraction of the sentence-id range assigned to training.
    """

    def __init__(self, path='train.tsv', train_fraction=0.8):
        super().__init__()
        raw_rows, skipped = self._load_rows(path)
        max_sentence = max((row[I_SENTENCE_ID] for row in raw_rows), default=0)
        threshold = max_sentence * train_fraction
        # One comparison per row. The previous index loop re-sliced the list for every
        # qualifying row (quadratic), left the last qualifying row out of training
        # (off-by-one), relied on the file being sorted by sentence id, and never set
        # either attribute when no row qualified.
        self.training_rows = [row for row in raw_rows if row[I_SENTENCE_ID] < threshold]
        self.validation_rows = [row for row in raw_rows if row[I_SENTENCE_ID] >= threshold]
        # Stored under a private name: assigning ``self.training`` would shadow the
        # ``training()`` method, making ``dataset.training()`` raise TypeError.
        self._use_training_rows = True
        if skipped:
            tlog('Skipped {} malformed or out-of-range rows'.format(skipped))
        tlog('Finished loading training data')

    @staticmethod
    def _load_rows(path):
        """Parse ``path`` and return ``(rows, skipped)``: kept rows and the count of dropped rows."""
        rows = []
        skipped = 0
        # The csv module expects file objects opened with newline=''.
        with open(path, newline='') as tsvfile:
            tlog('Loading training data...')
            reader = csv.reader(tsvfile, delimiter='\t')
            next(reader, None)  # skip the header line
            for row in reader:
                try:
                    phrase_id = int(row[I_PHRASE_ID])
                    sentence_id = int(row[I_SENTENCE_ID])
                    label = int(row[I_LABEL])
                except (IndexError, ValueError):
                    # Blank/truncated lines or non-numeric id/label fields.
                    skipped += 1
                    continue
                # Label 0 is kept: the original `label > 0` dropped every label-0 row.
                # NOTE(review): assumes labels are 0-based sentiment classes (the Kaggle
                # Rotten Tomatoes train.tsv uses 0-4) — confirm against the data.
                if phrase_id > 0 and sentence_id > 0 and label >= 0:
                    row[I_PHRASE_ID] = phrase_id
                    row[I_SENTENCE_ID] = sentence_id
                    row[I_LABEL] = label
                    rows.append(row)
                else:
                    skipped += 1
        return rows, skipped

    def training(self):
        """Make ``len()`` and indexing use the training split."""
        self._use_training_rows = True

    def validating(self):
        """Make ``len()`` and indexing use the validation split."""
        self._use_training_rows = False

    def current_dataset(self):
        """Return the row list of the currently selected split."""
        return self.training_rows if self._use_training_rows else self.validation_rows

    def __len__(self):
        return len(self.current_dataset())

    def __getitem__(self, idx):
        """Return ``((phrase_id, sentence_id, phrase), label)`` for row ``idx`` of the current split."""
        row = self.current_dataset()[idx]
        return (row[I_PHRASE_ID], row[I_SENTENCE_ID], row[I_PHRASE]), row[I_LABEL]


In [0]:
dataset = RottenTomatoesDataset()

# Report the size of each split: training first, then validation.
for select_split in (dataset.training, dataset.validating):
    select_split()
    print(len(dataset))

Wed Apr 24 00:13:41 2019   Loading training data...
