In [None]:
import os
from collections import namedtuple

from tqdm.notebook import tqdm

In [None]:
# Show which files are available in the Sleep-EDF lite input directory.
DATA_DIR = '../input/sleepedf-lite-0'
os.listdir(DATA_DIR)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, TensorDataset

In [None]:
from sklearn.metrics import accuracy_score, f1_score

In [None]:
# Immutable record bundling every hyper-parameter the notebook reads via `args.<field>`.
Config = namedtuple(
    'Config',
    'time_step time_window batch_size learning_rate epochs input_channel',
)

In [None]:
# Concrete hyper-parameter values used by the data, model and training cells.
args = Config(
    # Data / model shape
    input_channel=2,     # matches the two EEG signals (Fpz-Cz, Pz-Oz) stacked by SleepEDFDataset
    time_window=20480,   # NOTE(review): differs from LocalEncoder's default time_length=3000 — confirm
    time_step=12,
    # Optimisation
    batch_size=64,
    learning_rate=1e-3,
    epochs=30,
)

# Data Processing

In [None]:
class SleepEDFDataset(Dataset):
    """Sleep-EDF epochs loaded from a ``.npz`` archive.

    The archive must contain ``eeg_fpz_cz`` and ``eeg_pz_oz`` (one row of samples
    per epoch; the two signals are stacked as channels) and ``annotation``
    (one label per epoch; stored here shifted down by 1).

    Args:
        path: path to the ``.npz`` archive.
        num_context: number of context epochs skipped before the returned epoch.
        num_future: number of future epochs skipped before the returned epoch.
        mode: ``'train'`` returns only the signal tensor; any other value
            returns ``(signal, label)``.

    Item ``i`` is the epoch at index ``i + num_context + 1 + num_future``;
    ``__len__`` is sized so that index never exceeds the last epoch.
    """

    def __init__(self, path, num_context, num_future, mode='train'):
        self.mode = mode
        self.num_context = num_context
        self.num_future = num_future

        # NpzFile keeps the archive open; the context manager closes it once
        # the arrays below have been materialised.
        with np.load(path) as data:
            fpz_cz = data['eeg_fpz_cz']
            pz_oz = data['eeg_pz_oz']
            # Each signal becomes (epochs, 1, samples); concatenating on axis 1
            # yields (epochs, 2, samples).
            self.data = np.concatenate(
                (fpz_cz.reshape(-1, 1, fpz_cz.shape[-1]),
                 pz_oz.reshape(-1, 1, pz_oz.shape[-1])),
                axis=1,
            )
            self.targets = data['annotation'] - 1

    def _epoch_index(self, item):
        # Position in self.data of the epoch returned for dataset item `item`.
        return item + self.num_context + 1 + self.num_future

    def __len__(self):
        # The last item must map to epoch index N - 1, hence the extra "- 1"
        # (without it the final item indexes one past the end).
        return max(0, self.data.shape[0] - self.num_context - self.num_future - 1)

    def __getitem__(self, item):
        # NOTE(review): SleepContrast.forward expects (batch, num, channel, length);
        # if whole windows are intended this should probably return
        # self.data[item : item + num_context + 1 + num_future] — confirm.
        idx = self._epoch_index(item)
        signal = torch.from_numpy(self.data[idx].astype(np.float32))
        if self.mode == 'train':
            return signal
        # np.long is gone from NumPy 1.24 (and is 32-bit C long on Windows in 2.x);
        # as_tensor also accepts the numpy scalar that indexing a 1-D array yields,
        # which torch.from_numpy rejects.
        label = torch.as_tensor(self.targets[idx], dtype=torch.long)
        return signal, label

    def __repr__(self):
        return f"""
               ****************************************
               Model  : {self.__class__.__name__}
               Length : {len(self)}
               ****************************************
                """

In [None]:
# Wrap the held-out arrays in a DataLoader for evaluation.
# NOTE(review): `x_test` / `y_test` are not created by any visible cell — they
# must come from an earlier (possibly deleted) cell; confirm their source.
# torch.as_tensor accepts both numpy arrays and tensors, so re-running this cell
# is safe (the original `.astype` failed once x_test had become a tensor), and
# torch.long replaces np.long, which NumPy 1.24 removed.
x_test = torch.as_tensor(x_test, dtype=torch.float32)
y_test = torch.as_tensor(y_test, dtype=torch.long)
test_dataset = TensorDataset(x_test, y_test)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

# Backbones

## Common Layers

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(ResidualBlock, self).__init__()
        
        self.conv1d = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1, padding=0, stride=1, bias=False),
            nn.BatchNorm(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2, stride=1, bias=False),
            nn.BatchNorm(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_channels, out_channels, kernel_size=1, padding=0, stride=1, bias=False)
        )
        
        self.shortcut = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1, padding=0, stride=1, bias=False)
        )
        
    def forward(self, x):
        return self.conv1d(x) + self.shortcut(x)

## The Local Encoder

In [None]:
class LocalEncoder(nn.Module):
    """Convolutional encoder mapping one multi-channel epoch to a feature vector.

    Two stages of 1x1 conv -> BatchNorm -> ReLU -> MaxPool1d(13, stride 13),
    a final 1x1 conv, then flatten -> ReLU -> Linear.

    Args:
        input_channel: number of input signal channels.
        time_length: samples per epoch; must be >= 13 * 13 so both pooling
            stages keep at least one time step.
        output_channel: number of convolutional feature maps.
        output_length: size of the returned feature vector.

    ``forward`` maps ``(batch, input_channel, time_length)`` to
    ``(batch, output_length)``.

    Raises:
        ValueError: if ``time_length`` is too short for the two pooling stages.
    """

    def __init__(self, input_channel=2, time_length=3000, output_channel=16, output_length=128):
        super(LocalEncoder, self).__init__()

        # Each MaxPool1d(kernel_size=13, stride=13) floors the length to L // 13;
        # the 1x1 convolutions leave the length unchanged.
        pooled_length = time_length // 13 // 13
        if pooled_length < 1:
            raise ValueError(
                f"time_length must be at least {13 * 13} so both pooling stages "
                f"keep a time step, got {time_length}")

        self.layers = nn.Sequential(
            nn.Conv1d(input_channel, output_channel, kernel_size=1, padding=0, stride=1, bias=False),
            nn.BatchNorm1d(output_channel),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=13, stride=13),
            nn.Conv1d(output_channel, output_channel, kernel_size=1, padding=0, stride=1, bias=False),
            nn.BatchNorm1d(output_channel),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=13, stride=13),
            nn.Conv1d(output_channel, output_channel, kernel_size=1, padding=0, stride=1, bias=False),
        )

        # Head on the flattened conv features. The original cell stopped at an
        # unfinished `self.avg_pool = nn.A` and never defined `non_linear` or
        # `linear`, which forward() uses.
        # NOTE(review): ReLU + Linear is a minimal completion — confirm the intended head.
        self.non_linear = nn.ReLU()
        self.linear = nn.Linear(output_channel * pooled_length, output_length)

    def forward(self, x):
        out = self.layers(x)
        out = out.view(out.shape[0], -1)
        out = self.non_linear(out)
        out = self.linear(out)

        return out

In [None]:
class ContextEncoder(nn.Module):
    """GRU that aggregates a sequence of per-epoch feature vectors.

    The original constructor referenced the undefined globals ``output_length``,
    ``gru_hidden_size``, ``gru_layers`` and ``gru_dropout``; they are now
    keyword parameters with defaults, so ``ContextEncoder()`` still works.

    Args:
        output_length: size of each input feature vector (GRU ``input_size``);
            default 128 equals LocalEncoder's default ``output_length``.
        gru_hidden_size: GRU hidden size.
        gru_layers: number of stacked GRU layers.
        gru_dropout: dropout between stacked GRU layers (PyTorch applies it only
            when ``gru_layers > 1``).
        batch_first: if True, inputs are ``(batch, seq, feature)``.
            NOTE(review): True is assumed because SleepContrast builds features as
            (batch, num, feature_dim) — confirm.
    """

    def __init__(self, output_length=128, gru_hidden_size=128, gru_layers=1,
                 gru_dropout=0.0, batch_first=True):
        super(ContextEncoder, self).__init__()

        self.gru = nn.GRU(input_size=output_length, hidden_size=gru_hidden_size,
                          num_layers=gru_layers, dropout=gru_dropout,
                          batch_first=batch_first)

    def forward(self, x, hidden=None):
        """Run the GRU over ``x``; returns ``(output, h_n)`` exactly as ``nn.GRU`` does."""
        return self.gru(x, hidden)

In [None]:
class StatePredictor(nn.Module):
    """Linear head predicting ``num_future`` feature vectors from a context state.

    The original body ended in an unfinished ``self.`` statement (a SyntaxError),
    so the cell could not run at all.
    NOTE(review): this is a minimal placeholder architecture — confirm the intended predictor.

    Args:
        context_size: size of the input context vector (default 128 equals
            ContextEncoder's default hidden size).
        feature_dim: size of each predicted feature vector.
        num_future: number of future feature vectors to predict.

    ``forward`` maps ``(..., context_size)`` to ``(..., num_future, feature_dim)``.
    """

    def __init__(self, context_size=128, feature_dim=128, num_future=1):
        super(StatePredictor, self).__init__()

        self.feature_dim = feature_dim
        self.num_future = num_future
        self.linear = nn.Linear(context_size, feature_dim * num_future)

    def forward(self, context):
        out = self.linear(context)
        # Split the last axis into (num_future, feature_dim).
        return out.view(*context.shape[:-1], self.num_future, self.feature_dim)

# The Sleep Contrast Model

In [None]:
class SleepContrast(nn.Module):
    """Contrastive model over windows of consecutive sleep epochs (work in progress).

    Every epoch in an input window is embedded by ``LocalEncoder``. The window of
    features is then split into a context part (the first ``num_context + 1``
    epochs, whose last element is the anchor) and a future part (the last
    ``num_future`` epochs). The prediction step and the loss are still TODO.

    Args:
        feature_dim: size of each per-epoch feature vector; used as the encoder's
            ``output_length`` and as the memory-bank width.
        num_context: number of context epochs before the anchor.
        num_future: number of trailing future epochs.
        input_channel: signal channels per epoch (forwarded to LocalEncoder).
        time_length: samples per epoch (forwarded to LocalEncoder).
        output_channel: conv feature maps (forwarded to LocalEncoder).
        total_size: number of rows in the randomly initialised memory bank.
    """

    def __init__(self, feature_dim, num_context, num_future,
                 input_channel=2, time_length=3000, output_channel=16, total_size=0):
        super(SleepContrast, self).__init__()

        self.feature_dim = feature_dim
        self.num_context = num_context
        self.num_future = num_future

        # Local Encoder: its output size must be feature_dim because forward()
        # reshapes the encoder output to (batch, num, feature_dim).
        # (These arguments were previously undefined globals.)
        self.encoder = LocalEncoder(input_channel, time_length, output_channel, feature_dim)

        # Memory bank, registered as a buffer (not a trainable parameter).
        memory_bank = torch.randn(total_size, feature_dim)
        self.register_buffer('memory_bank', memory_bank)

        # Aggregator
        # NOTE(review): built with its defaults; make sure its input size matches feature_dim.
        self.aggregator = ContextEncoder()

        # Predictor
        self.predictor = StatePredictor()

    def _initialize_weights(self, module):
        """Zero every bias and orthogonally initialise (gain 0.1) every weight matrix of ``module``.

        1-D weights (e.g. BatchNorm scales) are left untouched:
        ``nn.init.orthogonal_`` raises on tensors with fewer than 2 dimensions.
        """
        for name, param in module.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name and param.dim() >= 2:
                nn.init.orthogonal_(param, 0.1)

    def compute_loss(self):
        # TODO: contrastive loss not implemented yet.
        pass

    def forward(self, x):
        """Encode a batch of epoch windows and split the features into context and future.

        Args:
            x: tensor of shape (batch, num, channel, length).

        Returns:
            ``(context, future)`` with shapes (batch, num_context + 1, feature_dim)
            and (batch, num_future, feature_dim), assuming
            ``num >= num_context + 1`` and ``num >= num_future``.
        """
        # Get features
        (batch, num, channel, length) = x.shape
        x = x.view(batch*num, channel, length)
        feature = self.encoder(x)
        feature = feature.view(batch, num, self.feature_dim)

        # feature[:, self.num_context] is the anchor (last element of `context`).
        context = feature[:, :self.num_context+1, :].contiguous()
        # Slice (not index) so `future` keeps its time axis; `num - num_future`
        # also yields an empty slice for num_future == 0 (where `-0:` would take everything).
        future = feature[:, num - self.num_future:, :].contiguous()

        # TODO: predictions via self.aggregator / self.predictor and the loss.
        return context, future

In [None]:
# NOTE(review): `CPC` is neither defined nor imported by any visible cell (the
# models defined above are LocalEncoder / ContextEncoder / StatePredictor /
# SleepContrast), so this raises NameError on a fresh kernel — confirm which
# model is intended. `.cuda()` also requires a CUDA device.
model = CPC(
    args.time_step,
    args.batch_size,
    args.time_window,
    in_channel=args.input_channel,
)
model.cuda()

In [None]:
# Adam with the AMSGrad variant, optimising only parameters that require gradients.
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=args.learning_rate,
    betas=(0.9, 0.98),
    eps=1e-09,
    weight_decay=1e-4,
    amsgrad=True,
)

In [None]:
# Training loop: one tqdm progress bar per epoch, showing the running mean loss/accuracy.
# NOTE(review): this cell assumes a CPC-style API (`model.init_hidden`,
# `model(x, hidden)` returning (acc, loss, hidden)) and a `train_dataloader`
# yielding (x, y, idx) — neither is defined by the visible cells.
model.train()
for epoch in range(args.epochs):
    loss_list, acc_list = [], []
    with tqdm(train_dataloader, desc=f'EPOCH: [{epoch+1}/{args.epochs}]') as progress_bar:
        for x, y, idx in progress_bar:
            x, y, idx = x.cuda(), y.cuda(), idx.cuda()

            # Forward / backward / update.
            optimizer.zero_grad()
            hidden = model.init_hidden(len(x), use_gpu=True)
            acc, loss, hidden = model(x, hidden)
            loss.backward()
            optimizer.step()

            # Track per-batch metrics and refresh the bar with their running means.
            acc_list.append(acc.item())
            loss_list.append(loss.item())
            progress_bar.set_postfix({
                'loss': np.mean(loss_list),
                'acc': np.mean(acc_list),
            })