In [1]:
# utils 
import torch

# data 
from torchtext.datasets import imdb
from torchtext.data import Field, BucketIterator

# model 
import torch.nn as nn
import torch.nn.functional as F

# training 
import torch.optim as optim
import tqdm

In [2]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# device

### Data Preparation

In [9]:
# Review text: lowercased, tokenized with spaCy, and batched with the batch
# dimension first.
text = Field(
    lower=True,
    tokenize="spacy",
    batch_first=True,
)
# Sentiment label: a single categorical value per example, not a token sequence.
# sequential=False skips tokenization and padding, and unk_token=None keeps
# <unk> out of the label vocabulary so it holds only the real classes.
label = Field(
    sequential=False,
    unk_token=None,
    is_target=True,
)

In [10]:
# Download IMDB (if needed) and wire the two fields to its examples.
# NOTE(review): IMDB.splits defaults to the train/test directories, so `val`
# is presumably the official test split rather than a held-out validation set
# -- confirm before using it for model selection.
train, val = imdb.IMDB.splits(
    text_field=text,
    label_field=label,
)

In [11]:
# Build both vocabularies from the training split only.
MIN_FREQ = 2  # minimum token count required to enter the text vocabulary
text.build_vocab(train, min_freq=MIN_FREQ)
label.build_vocab(train)

In [8]:
# Bucketed iterators for both splits, placed on `device`.
BATCH_SIZE = 64  # shared by the train and val iterators

train_loader, val_loader = BucketIterator.splits(
    datasets=(train, val),
    batch_sizes=(BATCH_SIZE,) * 2,
    device=device,
)

### Model

In [8]:
class GRU(nn.Module):
    """Bidirectional, multi-layer GRU encoder over embedded token ids.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding (the GRU input size).
        hidden_size: hidden units per direction in every GRU layer.
        n_classes: accepted for interface compatibility; currently unused
            because this module has no classification head.
        dropout: dropout probability applied between stacked GRU layers.
        num_layers: number of stacked GRU layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, n_classes = 2, dropout = 0.15, num_layers = 4):

        super(GRU, self).__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.gru = nn.GRU(
            input_size = embedding_dim,
            hidden_size = hidden_size,
            num_layers = num_layers,
            # nn.GRU only applies dropout between layers; with a single layer
            # a non-zero value does nothing except emit a warning.
            dropout = dropout if num_layers > 1 else 0.0,
            bidirectional = True,
            batch_first = True
        )

    def forward(self, x, hidden=None):
        """Embed `x` and run it through the GRU.

        Args:
            x: integer token ids, batch dimension first.
            hidden: optional initial hidden state of shape
                (2 * num_layers, batch, hidden_size). When None, nn.GRU
                starts from zeros; a better initial state can be supplied
                by the caller.

        Returns:
            Tuple `(outputs, hidden)` as returned by nn.GRU: the per-step
            outputs (both directions concatenated on the last axis) and the
            final hidden state of every layer and direction.
        """
        embedded = self.embedding(x)

        # Forward the caller-supplied state; previously it was silently dropped.
        outputs, hidden = self.gru(embedded, hidden)

        return outputs, hidden
        

In [9]:
# Encoder hyper-parameters, named so they are easy to find and tune.
EMBEDDING_DIM = 100
HIDDEN_SIZE = 64

# Vocabulary size comes from the text field; the remaining GRU arguments keep
# their defaults.
model = GRU(len(text.vocab), EMBEDDING_DIM, HIDDEN_SIZE)
model = model.to(device)

In [80]:
# Smoke-test the forward pass on one real batch. `x` is defined here so the
# cell works on a fresh kernel instead of relying on leftover state.
batch = next(iter(train_loader))
x = batch.text  # token ids for one batch; the text Field uses batch_first=True

# This is only a shape check, so skip building the autograd graph: it cuts
# the activation memory that long reviews would otherwise hold on the GPU.
with torch.no_grad():
    outputs, hidden = model(x)

RuntimeError: CUDA out of memory. Tried to allocate 984.00 MiB (GPU 0; 5.80 GiB total capacity; 4.40 GiB already allocated; 583.31 MiB free; 4.43 GiB reserved in total by PyTorch)

In [81]:
# Expected layout for a bidirectional, batch_first GRU:
#   outputs -> (batch, seq_len, 2 * hidden_size)
#   hidden  -> (2 * num_layers, batch, hidden_size)  (never batch-first)
print(outputs.size(), hidden.size())

torch.Size([64, 1158, 128]) torch.Size([16, 64, 64])
