## Imports

In [13]:
from typing import Callable, List, Optional
import torch
import pickle
import os
import copy
from tqdm.auto import tqdm
import gdown
import typing as tp
from sentence_transformers import SentenceTransformer
import pickle

from torch_geometric.data import Data, InMemoryDataset, Dataset, download_url
from torch_geometric.loader import DataLoader

from embedding import create_embedding

## InMemoryDataset 

In [6]:
class IMCAG(InMemoryDataset):
    url = 'https://drive.google.com/uc?id=1QH2WNnx4X7Qm6kDgG6Ry8c5fNSGnjZsI'
    # short_url = ''
  
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
    
    @property
    def raw_file_names(self):
        # The name of the files in the self.raw_dir folder that must be present in order to skip downloading.
        return ['data_v2_0.pickle']

    @property
    def processed_file_names(self):
        # The name of the files in the self.processed_dir folder that must be present in order to skip processing.
        return ['data.pt']

    # def download(self):
    #     for f in self.raw_file_names:
    #         download_url(os.path.join(self.url, f), self.raw_dir)
    
    def download(self):
        gdown.download(self.url, os.path.join(self.raw_dir, self.raw_file_names[0]), quiet=True)
    
    def load_pickle(self, path: str):
        with open(path, 'rb') as f:
            data = pickle.load(f)
        return data

    def process(self):
        data_list = []
        files = [f for f in os.listdir(self.raw_dir) if not os.path.isdir(f)]
        for f in files:
            data_list.extend(self.load_pickle(os.path.join(self.raw_dir, f)))

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

## DataLoader

In [None]:
class GraphDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.embedding = create_embedding()
        
    def preprocess(self, batch):
        batch = copy.copy(batch)
        x = batch.x
        for i in range(len(x)):
            x[i] = self.embedding(x[i])
        batch.x = torch.vstack(x)
        return batch
    
    def __iter__(self):
        self.iterator = super().__iter__()   
        return self
    
    def __next__(self):
        batch = next(self.iterator)
        return self.preprocess(batch)

## Check

In [7]:
dataset = IMCAG('data/')

Processing...
Done!


In [83]:
len(dataset)

128039

In [99]:
dataset = dataset.shuffle()

In [101]:
train_dataset = dataset[:800]
test_dataset = dataset[800:1001]

In [102]:
train_loader = GraphDataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = GraphDataLoader(test_dataset, batch_size=32)

## NN

In [103]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
from embedding import create_embedding  

In [104]:
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(384, 512)
        self.conv2 = GCNConv(512, 64)
        self.lin = Linear(64, 2)

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.1, training=self.training)
        x = self.lin(x)
        return x

In [107]:
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.03)
criterion = torch.nn.CrossEntropyLoss()

In [96]:
def train():
    model.train()
    for data in tqdm(train_loader, total=25):
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step() 
        optimizer.zero_grad()  
    

def test(loader):
    model.eval()
    avg_loss = 0
    correct = 0
    for i, data in enumerate(loader):  
        out = model(data)  
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum()) 
        avg_loss += criterion(out, data.y)
        
    avg_loss /= i
    return correct / len(loader.dataset), avg_loss


for epoch in range(1, 10):
    train()
    train_acc, avg_train_loss = test(train_loader)
    test_acc, avg_test_loss = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}, Train loss: {avg_train_loss:.6f}, Test loss: {avg_test_loss:.6f}')

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 001, Train Acc: 0.7125, Test Acc: 0.7164, Train loss: 0.508731, Test loss: 0.588494


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 002, Train Acc: 0.8363, Test Acc: 0.8408, Train loss: 0.437841, Test loss: 0.543081


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 003, Train Acc: 0.8975, Test Acc: 0.9055, Train loss: 0.413417, Test loss: 0.466652


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 004, Train Acc: 0.9000, Test Acc: 0.8905, Train loss: 0.343656, Test loss: 0.467784


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 005, Train Acc: 0.9050, Test Acc: 0.9055, Train loss: 0.341485, Test loss: 0.388448


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 006, Train Acc: 0.9050, Test Acc: 0.9055, Train loss: 0.346074, Test loss: 0.388801


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 007, Train Acc: 0.8812, Test Acc: 0.8109, Train loss: 0.330465, Test loss: 0.398121


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 008, Train Acc: 0.9038, Test Acc: 0.9005, Train loss: 0.329161, Test loss: 0.402637


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 009, Train Acc: 0.9025, Test Acc: 0.8955, Train loss: 0.323599, Test loss: 0.393539


## Зашафленный датасет, lr = 0.03

In [106]:
def train():
    model.train()
    for data in tqdm(train_loader, total=25):
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step() 
        optimizer.zero_grad()  
    

def test(loader):
    model.eval()
    avg_loss = 0
    correct = 0
    for i, data in enumerate(loader):  
        out = model(data)  
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum()) 
        avg_loss += criterion(out, data.y)
        
    avg_loss /= i
    return correct / len(loader.dataset), avg_loss


for epoch in range(1, 10):
    train()
    train_acc, avg_train_loss = test(train_loader)
    test_acc, avg_test_loss = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}, Train loss: {avg_train_loss:.6f}, Test loss: {avg_test_loss:.6f}')

  0%|          | 0/25 [00:00<?, ?it/s]



Epoch: 001, Train Acc: 0.6650, Test Acc: 0.7363, Train loss: 0.509741, Test loss: 0.517175


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 002, Train Acc: 0.8462, Test Acc: 0.8358, Train loss: 0.460329, Test loss: 0.460935


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 003, Train Acc: 0.8488, Test Acc: 0.8358, Train loss: 0.465189, Test loss: 0.441848


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 004, Train Acc: 0.7950, Test Acc: 0.7363, Train loss: 0.459212, Test loss: 0.505561


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 005, Train Acc: 0.7963, Test Acc: 0.7363, Train loss: 0.445718, Test loss: 0.494289


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 006, Train Acc: 0.8375, Test Acc: 0.8209, Train loss: 0.436312, Test loss: 0.481490


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 007, Train Acc: 0.8063, Test Acc: 0.7562, Train loss: 0.412358, Test loss: 0.506198


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 008, Train Acc: 0.7712, Test Acc: 0.7662, Train loss: 0.472222, Test loss: 0.449753


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 009, Train Acc: 0.8475, Test Acc: 0.8060, Train loss: 0.403113, Test loss: 0.463352
