<a href="https://colab.research.google.com/github/krisdmitrieva/DL_HW/blob/main/DL_lstm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m17.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6


In [2]:
from torch.utils.data import DataLoader
import torch
import pandas
from datasets import load_dataset_builder
from datasets import load_dataset

In [3]:
torch.manual_seed(14)

ds_builder = load_dataset_builder("conll2000")
dataset = load_dataset("conll2000", split="train")
print(dataset)

Downloading builder script:   0%|          | 0.00/7.47k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/8.77k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/612k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/140k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/8937 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2013 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'tokens', 'pos_tags', 'chunk_tags'],
    num_rows: 8937
})


In [4]:
ix_tag_ref = {'0': "''''",
            '1': "#",
            '2': "$",
            '3': "(",
            '4': ")",
            '5': ",",
            '6': ".",
            '7': ":",
            '8': "``",
            '9': "CC",
            '10':"CD",
            '11':"DT",
            '12': "EX",
            '13': "FW",
            '14': "IN",
            '15': "JJ",
            '16': "JJR",
            '17': "JJS",
            '18': "MD",
            '19': "NN",
            '20': "NNP",
            '21': "NNPS",
            '22': "NNS",
            '23': "PDT",
            '24': "POS",
            '25': "PRP",
            '26': "PRP$",
            '27': "RB",
            '28': "RBR",
            '29': "RBS",
            '30': "RP",
            '31': "SYM",
            '32': "TO",
            '33': "UH",
            '34': "VB",
            '35': "VBD",
            '36': "VBG",
            '37': "VBN",
            '38': "VBP",
            '39': "VBZ",
            '40': "WDT",
            '41': "WP",
            '42': "WP$",
            '43': "WRB"}

In [5]:
sentences = dataset['tokens'][0:1000]

In [6]:
tags = dataset['pos_tags'][0:1000]

In [7]:
def merge(list1, list2):

    merged_list = [(list1[i], list2[i]) for i in range(0, len(list1))]

    return merged_list

In [8]:
training_data = merge(sentences, tags)

In [9]:
word_to_ix = {}
for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix:  # word has not been assigned an index yet
            word_to_ix[word] = len(word_to_ix)

In [10]:
len(word_to_ix)

4920

In [11]:
EMBEDDING_DIM = 100
HIDDEN_DIM = 100

In [12]:
class LSTMTagger(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_layer = torch.nn.Embedding(4920, EMBEDDING_DIM)
        self.lstm = torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM)

        self.pos_predictor = torch.nn.Linear(HIDDEN_DIM, 44)

    def forward(self, token_ids):
        embeds = self.embedding_layer(token_ids)
        lstm_out, _ = self.lstm(embeds.view(len(token_ids), 1, -1))
        logits = self.pos_predictor(lstm_out.view(len(token_ids), -1))
        probs = torch.nn.functional.softmax(logits, dim=1)

        return probs

In [13]:
model = LSTMTagger()
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

In [14]:
for epoch in range(15):  # again, normally you would NOT do 300 epochs, it is toy data
    for sentence, tags in training_data:
        # Step 1. Remember that Pytorch accumulates gradients.
        # We need to clear them out before each instance
        model.zero_grad()

        # Step 2. Get our inputs ready for the network, that is, turn them into
        # Tensors of word indices.
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = torch.tensor(tags, dtype=torch.long)

        # Step 3. Run our forward pass.
        tag_scores = model(sentence_in)

        # Step 4. Compute the loss, gradients, and update the parameters by
        #  calling optimizer.step()
        loss = loss_function(tag_scores, targets)
        loss.backward()
        optimizer.step()

In [15]:
# See what the scores are after training
with torch.no_grad():
    inputs = prepare_sequence(training_data[0][0], word_to_ix)
    tag_scores = model(inputs)

    print(tag_scores)

tensor([[1.7139e-03, 1.9760e-03, 1.6397e-03,  ..., 1.9745e-03, 1.8432e-03,
         1.6271e-03],
        [6.2944e-05, 7.5199e-05, 5.1318e-05,  ..., 8.1260e-05, 8.1851e-05,
         5.7030e-05],
        [2.8506e-08, 4.2454e-08, 3.8262e-08,  ..., 5.0449e-08, 4.5811e-08,
         3.4257e-08],
        ...,
        [5.1674e-07, 6.2593e-07, 5.2052e-07,  ..., 5.2888e-07, 5.8917e-07,
         3.1584e-07],
        [1.5488e-05, 1.7044e-05, 1.9372e-05,  ..., 1.5559e-05, 1.6380e-05,
         8.8174e-06],
        [4.9177e-07, 5.0563e-07, 3.1183e-07,  ..., 3.0457e-07, 4.9809e-07,
         2.4961e-07]])
