In [None]:
import os
from collections import namedtuple

from tqdm.notebook import tqdm
try:
    from rich.progress import track
except:
    !pip install rich
    from rich.progress import track

In [None]:
os.listdir('../input/sleepedf-lite-0')

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

In [None]:
from sklearn.metrics import accuracy_score, f1_score

In [None]:
# Immutable record of every hyper-parameter read by the cells below.
# Field order is significant only for positional construction.
Config = namedtuple(
    typename='Config',
    field_names=(
        'seq_len',
        'input_channels',
        'hidden_channels',
        'stride',
        'batch_size',
        'num_seq',
        'pred_steps',
        'feature_dim',
        'learning_rate',
        'train_ratio',
        'epochs',
        'save_path',
        'num_classes',
        'finetune_ratio',
        'finetune_epochs',
    ),
)

In [None]:
# Hyper-parameters shared by the pre-training, fine-tuning and evaluation cells.
args = Config(
    seq_len=20,               # number of consecutive epochs in one dataset window
    stride=1,                 # step between the starts of consecutive windows
    input_channels=2,         # the dataset stacks two EEG channels per epoch
    hidden_channels=16,       # width of the first ResNet stage; later stages use 2x/4x/8x
    batch_size=16,
    num_seq=20,               # window length assumed when building the contrastive targets; keep equal to seq_len
    pred_steps=5,             # trailing steps withheld from the GRU and predicted instead
    feature_dim=128,          # encoder output size, also used as the GRU hidden size
    learning_rate=1e-3,
    train_ratio=0.7,          # fraction of windows placed in the training split
    epochs=10,                # pre-training epochs
    save_path='/kaggle/working/check_points/',  # directory that receives checkpoints
    num_classes=5,            # output size of the fine-tuning classifier
    finetune_ratio=0.1,       # fraction of the training split used for fine-tuning
    finetune_epochs=10
)

# Data Preparation

In [None]:
class SleepEDFDataset(Dataset):
    """Sliding windows of consecutive epochs read from the files in a directory.

    Every file is loaded with ``np.load`` and must provide the keys
    ``'eeg_fpz_cz'``, ``'eeg_pz_oz'`` and ``'annotation'``. The two EEG arrays
    are stacked on a new channel axis, all loaded files are concatenated along
    the epoch axis, and windows of ``seq_len`` consecutive epochs are cut every
    ``stride`` epochs. Windows are materialised eagerly, so memory grows by
    roughly ``seq_len / stride`` over the raw data.

    Args:
        path: Directory containing the recording files.
        seq_len: Number of consecutive epochs per window.
        stride: Step (in epochs) between the starts of consecutive windows.
        patients: Number of files to load, taken in sorted file-name order.
        return_label: When true, ``__getitem__`` returns ``(data, labels)``
            instead of ``data`` alone.

    Raises:
        FileNotFoundError: If ``path`` is not an existing directory.
        ValueError: If no file was loaded or no complete window fits.
    """

    def __init__(self, path, seq_len, stride=1, patients=2, return_label=False):
        self.return_label = return_label
        self.seq_len = seq_len

        if not os.path.isdir(path):
            raise FileNotFoundError(f'Dataset directory not found: {path}')
        # os.listdir returns entries in arbitrary order; sort so that `patients`
        # always selects the same files.
        file_names = sorted(os.listdir(path))[:patients]

        candidate_data = []
        candidate_target = []
        for filename in file_names:
            data = np.load(os.path.join(path, filename))
            fpz_cz = data['eeg_fpz_cz']
            pz_oz = data['eeg_pz_oz']
            # (epochs, samples) per channel -> (epochs, 2, samples)
            candidate_data.append(
                np.stack(
                    (fpz_cz.reshape(-1, fpz_cz.shape[-1]),
                     pz_oz.reshape(-1, pz_oz.shape[-1])),
                    axis=1)
            )
            # Shift the stored annotations down by one.
            # NOTE(review): presumably maps 1-based stage ids to 0-based class
            # indices - confirm against the data files.
            candidate_target.append(data['annotation'] - 1)

        if not candidate_data:
            raise ValueError(f'No recordings were loaded from {path}')
        candidate_data = np.concatenate(candidate_data, axis=0)
        candidate_target = np.concatenate(candidate_target, axis=0)

        # Start indices of every window that fits completely into the data.
        starts = range(0, len(candidate_data) - seq_len + 1, stride)
        if len(starts) == 0:
            raise ValueError(
                f'seq_len={seq_len} exceeds the {len(candidate_data)} available epochs'
            )

        windows = []
        labels = []
        for start in tqdm(starts):
            windows.append(candidate_data[start: start + seq_len])
            labels.append(candidate_target[start: start + seq_len])
        self.data = np.stack(windows, axis=0)      # (windows, seq_len, 2, samples)
        self.targets = np.stack(labels, axis=0)    # (windows, seq_len)

    def __len__(self):
        """Return the number of windows."""
        return len(self.data)

    def __getitem__(self, item):
        """Return window ``item`` as float32, plus its int64 labels if requested."""
        window = torch.from_numpy(self.data[item].astype(np.float32))
        if not self.return_label:
            return window
        # np.int64 is explicit and portable; the np.long alias was removed in
        # NumPy 1.24 and is platform dependent where it exists.
        return window, torch.from_numpy(self.targets[item].astype(np.int64))

    def __repr__(self):
        return f"""
               ****************************************
               Model  : {self.__class__.__name__}
               Length : {len(self)}
               ****************************************
                """

# Backbones

In [None]:
class ResidualBlock(nn.Module):
    """1-D residual block made of three convolutions and a shortcut branch.

    Only the second convolution applies ``stride``, so it is the only one that
    shortens the time axis. When ``stride != 1`` or the channel counts differ,
    the shortcut is a strided 1x1 convolution followed by batch norm so that
    both branches have the same shape; otherwise it is the identity.

    Args:
        input_channels: Channels of the incoming tensor.
        output_channels: Channels produced by the block.
        kernel_sizes: Three kernel sizes, one per convolution. Padding is
            ``kernel // 2`` for each.
        stride: Stride of the second convolution and of the shortcut.
        dropout: Dropout probability applied after the second convolution.

    Raises:
        ValueError: If ``kernel_sizes`` does not contain exactly three entries.
    """

    def __init__(self, input_channels, output_channels, kernel_sizes=(7, 11, 7), stride=1, dropout=0.2):
        super().__init__()
        if len(kernel_sizes) != 3:
            raise ValueError(f'kernel_sizes must hold exactly 3 values, got {len(kernel_sizes)}')

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride

        self.conv1 = nn.Sequential(
            nn.Conv1d(input_channels, output_channels, kernel_size=kernel_sizes[0], stride=1,
                      padding=kernel_sizes[0]//2, bias=False),
            nn.BatchNorm1d(output_channels),
            nn.ReLU(inplace=True)
        )

        # Only conv2 reduces the temporal resolution.
        self.conv2 = nn.Sequential(
            nn.Conv1d(output_channels, output_channels, kernel_size=kernel_sizes[1], stride=stride,
                      padding=kernel_sizes[1]//2, bias=False),
            nn.BatchNorm1d(output_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )

        self.conv3 = nn.Sequential(
            nn.Conv1d(output_channels, output_channels, kernel_size=kernel_sizes[2], stride=1,
                      padding=kernel_sizes[2]//2, bias=False),
            nn.BatchNorm1d(output_channels),
        )

        self.relu = nn.ReLU(inplace=True)

        # Shortcut branch: an empty Sequential acts as the identity. A 1x1
        # convolution is used instead whenever the main branch changes the
        # number of channels or the temporal length.
        self.downsample = nn.Sequential()
        if stride != 1 or input_channels != output_channels:
            self.downsample = nn.Sequential(
                nn.Conv1d(input_channels, output_channels, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm1d(output_channels)
            )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, input_channels, length)``."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        # `downsample` is the identity when input and output shapes already match.
        out += self.downsample(x)
        return self.relu(out)

In [None]:
class ResNet(nn.Module):
    """1-D ResNet: a 1x1 convolutional stem, four residual stages and a linear head.

    Each stage holds two residual blocks. Stages 2-4 double the channel count
    and use stride 2 in their first block, so the time axis is shortened three
    times. The result is averaged over time and projected to ``num_classes``
    values.

    Args:
        input_channels: Channels of the input signal.
        hidden_channels: Channels after the stem and in the first stage.
        num_classes: Size of the output vector.
        kernel_sizes: Three kernel sizes used by every residual block.
    """

    def __init__(self, input_channels, hidden_channels, num_classes, kernel_sizes=(7, 11, 7)):
        super().__init__()

        # Stem: a pointwise convolution that only lifts the channel count and
        # leaves the temporal length untouched.
        self.conv1 = nn.Sequential(
            nn.Conv1d(input_channels, hidden_channels, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(inplace=True)
        )

        # Residual stages
        self.layer1 = self.__make_layer(ResidualBlock, hidden_channels, hidden_channels, 2, kernel_sizes, stride=1)
        self.layer2 = self.__make_layer(ResidualBlock, hidden_channels, hidden_channels*2, 2, kernel_sizes, stride=2)
        self.layer3 = self.__make_layer(ResidualBlock, hidden_channels*2, hidden_channels*4, 2, kernel_sizes, stride=2)
        self.layer4 = self.__make_layer(ResidualBlock, hidden_channels*4, hidden_channels*8, 2, kernel_sizes, stride=2)

        # Average over the last (time) dimension.
        self.avg_pool = nn.AdaptiveAvgPool1d(1)

        # Dense output layer
        self.fc = nn.Linear(hidden_channels*8, num_classes)

    def __make_layer(self, block, input_channels, output_channels, num_blocks, kernel_sizes, stride):
        """Build one stage of ``num_blocks`` blocks.

        Only the first block changes the channel count and applies ``stride``;
        every block receives the same ``kernel_sizes``.
        """
        layers = [block(input_channels, output_channels, kernel_sizes, stride=stride)]
        layers.extend(
            block(output_channels, output_channels, kernel_sizes, stride=1)
            for _ in range(1, num_blocks)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, input_channels, length)`` to ``(batch, num_classes)``.

        Each convolution changes the length as
        ``L_out = floor((L_in + 2*padding - kernel) / stride + 1)``.
        """
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avg_pool(out)
        out = out.view(x.size(0), -1)
        return self.fc(out)

In [None]:
class GRU(nn.Module):
    """Thin wrapper around a batch-first ``nn.GRU`` with a helper for the initial state.

    Args:
        input_size: Size of each input vector.
        hidden_size: Size of the hidden state.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability passed to ``nn.GRU``.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.3):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                          batch_first=True, dropout=dropout)

    def forward(self, x, h_0):
        """Run the GRU.

        Args:
            x: Input of shape ``(batch, seq_len, input_size)``.
            h_0: Initial state of shape ``(num_layers, batch, hidden_size)``.

        Returns:
            ``(out, h_n)`` with ``out`` of shape ``(batch, seq_len, hidden_size)``
            and ``h_n`` of shape ``(num_layers, batch, hidden_size)``.
        """
        return self.gru(x, h_0)

    def init_hidden(self, batch_size):
        """Return a randomly initialised hidden state for ``batch_size`` sequences.

        The state is created on the device (and with the dtype) of the GRU
        weights, so the module works on CPU as well as on GPU.
        """
        reference = next(self.parameters())
        return torch.randn(self.num_layers, batch_size, self.hidden_size,
                           device=reference.device, dtype=reference.dtype)

In [None]:
class StatePredictor(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The first layer maps ``input_dim`` to ``output_dim``; the second keeps
    ``output_dim``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        # Layers are created in application order so that parameter
        # initialisation and state-dict keys (pred.0, pred.2) stay stable.
        hidden = nn.Linear(input_dim, output_dim)
        activation = nn.ReLU(inplace=True)
        readout = nn.Linear(output_dim, output_dim)
        self.pred = nn.Sequential(hidden, activation, readout)

    def forward(self, x):
        """Apply the perceptron to the last dimension of ``x``."""
        predicted = self.pred(x)
        return predicted

# The Sleep Contrast Model

In [None]:
class SleepContrast(nn.Module):
    """Contrastive pre-training model: encoder, GRU aggregator and step predictor.

    Every step of a window is encoded separately. The GRU summarises all but
    the last ``pred_steps`` encodings, and the predictor then rolls forward
    ``pred_steps`` times, feeding each prediction back into the GRU. Each
    prediction is scored by dot product against the encoding of every step of
    every sample in the batch. ``compute_targets`` gives, for each prediction,
    the column of the step it is supposed to match.

    Args:
        input_channels: Channels of the raw signal.
        hidden_channels: Base width of the encoder.
        feature_dim: Encoder output size, also the GRU hidden size.
        pred_steps: Number of trailing steps to predict.
        num_seq: Steps per window assumed by ``compute_targets``.
        batch_size: Batch size assumed by ``compute_targets``.
        kernel_sizes: Kernel sizes forwarded to the encoder.
    """

    def __init__(self, input_channels, hidden_channels, feature_dim, pred_steps, num_seq, batch_size, kernel_sizes):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.feature_dim = feature_dim
        self.pred_steps = pred_steps
        self.num_seq = num_seq
        self.batch_size = batch_size
        self.kernel_sizes = kernel_sizes

        # Cache filled lazily by compute_targets().
        self.targets = None

        # Local encoder
        self.encoder = ResNet(input_channels, hidden_channels, feature_dim, kernel_sizes=kernel_sizes)

        # Aggregator
        self.gru = GRU(input_size=feature_dim, hidden_size=feature_dim, num_layers=2)

        # Predictor
        self.predictor = StatePredictor(input_dim=feature_dim, output_dim=feature_dim)

    def _device(self):
        """Device of the model parameters (CPU if the module has none)."""
        first = next(self.parameters(), None)
        return first.device if first is not None else torch.device('cpu')

    def compute_targets(self, recompute=False):
        """Return the class index of the positive pair for every prediction.

        Row ``i * pred_steps + j`` of the score matrix holds prediction ``j`` of
        sample ``i``. Its positive is step ``num_seq - pred_steps + j`` of the
        same sample, stored in column ``step * batch_size + i``.

        The result is cached. It is rebuilt when ``recompute`` is true or when
        the model has moved to another device since the cache was filled.

        Returns:
            int64 tensor of shape ``(batch_size * pred_steps,)`` on the model's
            device.
        """
        device = self._device()
        stale = self.targets is None or self.targets.device != device
        if recompute or stale:
            sample = torch.arange(self.batch_size, device=device).repeat_interleave(self.pred_steps)
            step = torch.arange(self.pred_steps, device=device).repeat(self.batch_size)
            self.targets = (self.num_seq - self.pred_steps + step) * self.batch_size + sample
        return self.targets

    def forward(self, x):
        """Score every prediction against every encoded step in the batch.

        Args:
            x: Tensor of shape ``(batch, num_seq, channel, seq_len)``.

        Returns:
            Scores of shape ``(batch * pred_steps, num_seq * batch)``.
        """
        # Extract features
        (batch, num_seq, channel, seq_len) = x.shape
        feature = self.encoder(x.reshape(batch*num_seq, channel, seq_len))
        feature = feature.view(batch, num_seq, self.feature_dim) # (batch, num_seq, feature_dim)

        # Context from every step except the ones to be predicted. The initial
        # state is sized from the actual batch, not the configured batch size.
        h_0 = self.gru.init_hidden(batch)
        # out:    (batch, num_seq - pred_steps, hidden_size)
        # h_next: (num_layers, batch, hidden_size)
        out, h_next = self.gru(feature[:, :-self.pred_steps, :], h_0)

        # Roll the predictor forward, feeding each prediction back into the GRU.
        pred = []
        c_next = out[:, -1, :]
        for _ in range(self.pred_steps):
            z_pred = self.predictor(c_next)
            pred.append(z_pred)
            c_seq, h_next = self.gru(z_pred.unsqueeze(1), h_next)
            c_next = c_seq[:, -1, :]
        pred = torch.stack(pred, 1) # (batch, pred_steps, feature_dim)

        # score[i, j, m, n] = <pred[i, j], feature[n, m]>
        score = torch.einsum('ijk,nmk->ijmn', pred, feature) # (batch, pred_steps, num_seq, batch)
        return score.reshape(batch*self.pred_steps, num_seq*batch)

# Self-supervised Pre-training

In [None]:
# Build the windowed dataset. `patients` is left at its default, so only two
# files from the directory are loaded. Labels are returned alongside the data
# because the same dataset also feeds fine-tuning and evaluation.
dataset = SleepEDFDataset(path='../input/sleepedf-lite-0', seq_len=args.seq_len, 
                          stride=args.stride, return_label=True)

In [None]:
from torch.utils.data import random_split
from torch.utils.data import DataLoader

In [None]:
# Random split over windows, not over recordings. Windows are cut every
# `args.stride` epochs and span `args.seq_len` epochs, so with stride < seq_len
# neighbouring windows share raw epochs and the two splits overlap in raw data.
# NOTE(review): consider splitting by recording if a leakage-free test estimate
# is required. No generator/seed is passed, so the split differs on every run.
train_size = int(len(dataset)*args.train_ratio)
train_dataset, test_dataset = random_split(dataset, [train_size, len(dataset)-train_size])

In [None]:
# drop_last=True keeps every batch at exactly args.batch_size windows, which the
# pre-computed contrastive targets (built for a fixed batch size) rely on.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, 
                          drop_last=True, shuffle=True, pin_memory=True)

In [None]:
model = SleepContrast(input_channels=args.input_channels, hidden_channels=args.hidden_channels, 
                      feature_dim=args.feature_dim, pred_steps=args.pred_steps, 
                      batch_size=args.batch_size, num_seq=args.num_seq, kernel_sizes=[7, 11, 7])

In [None]:
model = model.cuda()

In [None]:
optimizer = optim.Adam(model.parameters(), 
                       lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-09, 
                       weight_decay=1e-4, amsgrad=True)

In [None]:
criterion = nn.CrossEntropyLoss()

In [None]:
targets = model.compute_targets()

In [None]:
model.train()
device = next(model.parameters()).device  # follow the model instead of assuming CUDA

for epoch in range(args.epochs):
    loss_list = []

    # The self-supervised objective does not use the labels the loader yields.
    for x, _ in track(train_loader, description=f'EPOCH: [{epoch+1}/{args.epochs}]'):
        x = x.to(device)

        optimizer.zero_grad()
        score = model(x)
        loss = criterion(score, targets)

        loss.backward()
        optimizer.step()

        loss_list.append(loss.item())

    print(f'Loss: {np.mean(loss_list)}')

    # Checkpoint every 10th epoch. The file name carries the zero-based epoch index.
    if (epoch+1) % 10 == 0:
        # makedirs also creates missing parents and tolerates an existing directory.
        os.makedirs(args.save_path, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(args.save_path, f'model_epoch_{epoch}.pth'))

# Fine-tuning

In [None]:
class SleepClassifier(nn.Module):
    """Classifier built from an encoder and GRU aggregator plus a linear head.

    The encoder and GRU are constructed with the same arguments as in the
    contrastive model so that pre-trained weights can be loaded into them. The
    GRU reads all but the last ``pred_steps`` encoded steps, and the GRU output
    at the last step it read is classified.

    Args:
        input_channels: Channels of the raw signal.
        hidden_channels: Base width of the encoder.
        num_classes: Number of output classes.
        feature_dim: Encoder output size, also the GRU hidden size.
        pred_steps: Number of trailing steps that are not fed to the GRU.
        num_seq: Steps per window (stored, not used in ``forward``).
        batch_size: Configured batch size (stored, not used in ``forward``).
        kernel_sizes: Kernel sizes forwarded to the encoder.
    """

    def __init__(self, input_channels, hidden_channels, num_classes, feature_dim, pred_steps, num_seq, batch_size, kernel_sizes):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.pred_steps = pred_steps
        self.num_seq = num_seq
        self.batch_size = batch_size
        self.kernel_sizes = kernel_sizes

        # Local encoder
        self.encoder = ResNet(input_channels, hidden_channels, feature_dim, kernel_sizes=kernel_sizes)

        # Aggregator
        self.gru = GRU(input_size=feature_dim, hidden_size=feature_dim, num_layers=2)

        # Classifier head
        self.relu = nn.ReLU()
        self.mlp = nn.Linear(feature_dim, num_classes)

    def freeze_parameters(self):
        """Stop gradient updates for the encoder and the GRU (the head stays trainable)."""
        for p in self.encoder.parameters():
            p.requires_grad = False
        for p in self.gru.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Classify each window.

        Args:
            x: Tensor of shape ``(batch, num_seq, channel, seq_len)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        # Extract features
        (batch, num_seq, channel, seq_len) = x.shape
        feature = self.encoder(x.reshape(batch*num_seq, channel, seq_len))
        feature = feature.view(batch, num_seq, self.feature_dim) # (batch, num_seq, feature_dim)

        # Context feature. The initial state is sized from the actual batch so
        # that batches smaller than the configured batch size also work.
        h_0 = self.gru.init_hidden(batch)
        # context: (batch, num_seq - pred_steps, hidden_size)
        # h_n:     (num_layers, batch, hidden_size)
        context, h_n = self.gru(feature[:, :-self.pred_steps, :], h_0)

        # Classify the GRU output at the last step it was shown.
        context = context[:, -1, :]
        out = self.relu(context)
        return self.mlp(out)

In [None]:
classifier = SleepClassifier(input_channels=args.input_channels, hidden_channels=args.hidden_channels, 
                             num_classes=args.num_classes, feature_dim=args.feature_dim, 
                             pred_steps=args.pred_steps, batch_size=args.batch_size, 
                             num_seq=args.num_seq, kernel_sizes=[7, 11, 7])

In [None]:
classifier = classifier.cuda()

In [None]:
# Copy the pre-trained encoder into the classifier.
# load_state_dict copies values into the classifier's own tensors, so the two
# models do not share storage, and it also transfers the BatchNorm running
# statistics, which are buffers and therefore not covered by .parameters().
classifier.encoder.load_state_dict(model.encoder.state_dict())

In [None]:
# Copy the pre-trained GRU into the classifier.
# load_state_dict copies values into the classifier's own tensors, so
# fine-tuning does not modify the pre-trained model through shared storage.
classifier.gru.load_state_dict(model.gru.state_dict())

In [None]:
finetune_size = int(len(train_dataset)*args.finetune_ratio)
finetune_dataset, _ = random_split(train_dataset, [finetune_size, len(train_dataset)-finetune_size])

In [None]:
finetune_loader = DataLoader(finetune_dataset, batch_size=args.batch_size, 
                             drop_last=True, shuffle=True, pin_memory=True)

In [None]:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), 
                       lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-09, 
                       weight_decay=1e-4, amsgrad=True)
criterion = nn.CrossEntropyLoss()

In [None]:
classifier.train()
device = next(classifier.parameters()).device  # follow the classifier instead of assuming CUDA

# y holds one label per epoch of the window: (batch, seq_len). The classifier
# emits one prediction per window, from the GRU output at the last epoch it is
# shown, i.e. the epoch just before the `pred_steps` epochs it drops.
# NOTE(review): this assumes the aligned label is the intended target - confirm.
label_index = -args.pred_steps - 1

for epoch in range(args.finetune_epochs):
    for x, y in track(finetune_loader):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        y_hat = classifier(x)                        # (batch, num_classes)
        loss = criterion(y_hat, y[:, label_index])   # one label per window

        loss.backward()
        optimizer.step()

# Evaluation

In [None]:
# Evaluation needs no shuffling; a fixed order keeps predictions and labels
# reproducible across runs. drop_last=True keeps every batch at exactly
# args.batch_size windows, so a trailing partial batch is not evaluated.
test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                         drop_last=True, shuffle=False, pin_memory=True)

In [None]:
classifier.eval()
device = next(classifier.parameters()).device  # follow the classifier instead of assuming CUDA

# y holds one label per epoch of the window: (batch, seq_len). The classifier
# emits one prediction per window, from the GRU output at the last epoch it is
# shown, so keep only that epoch's label.
# NOTE(review): this assumes the aligned label is the intended target - confirm.
label_index = -args.pred_steps - 1

predictions = []
labels = []
with torch.no_grad():
    for x, y in track(test_loader):
        y_hat = classifier(x.to(device))

        labels.append(y[:, label_index].numpy())   # (batch,)
        predictions.append(y_hat.cpu().numpy())    # (batch, num_classes)

In [None]:
labels = np.concatenate(labels, axis=0)
predictions = np.concatenate(predictions, axis=0)

In [None]:
predictions = np.argmax(predictions, axis=1)

In [None]:
from sklearn.metrics import accuracy_score

In [None]:
accuracy = accuracy_score(labels, predictions)