In [103]:
!pip install pytorch-lightning
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import tqdm
import pytorch_lightning as L
PATH = '/kaggle/input/yandex-cup-ml-23-nowcasting/ML Cup 2023 Weather/train/'
from pytorch_lightning import seed_everything
import random
seed=7
seed_everything(seed, workers=True)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True



In [104]:
class SeasonDataset(data.Dataset):

    def __init__(self, list_of_files, season, in_seq_len=4, out_seq_len=12, mode='sequentially', rotate = 0, with_time=False):
        self.in_seq_len = in_seq_len
        self.out_seq_len = out_seq_len
        self.seq_len = in_seq_len + out_seq_len
        self.with_time = with_time
        self.__prepare_timestamps_mapping(list_of_files)
        self.__prepare_sequences(mode)
        self.rotate = rotate
        self.season = season

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        to_append = []
        data = []
        #targets = []
        for timestamp in self.sequences[index]:
            with h5py.File(self.timestamp_to_file[timestamp]) as d:
                #targets.append(np.array([1]))
                
                data.append(np.array([d[timestamp]['intensity']]))

                #data.append(np.array([d[timestamp]['intensity']]))

        data = np.array(data)
        #data = np.expand_dims(data, axis=1) ???
        targets = np.array([self.season])
        data[data == -1e6] = 0
        data[data == -2e6] = -1
        inputs = data[:self.in_seq_len]
        #targets = targets[self.in_seq_len:]
        if self.with_time:
            return (inputs, self.sequences[index][-1]), targets
        else:
            return inputs, targets

    def __prepare_timestamps_mapping(self, list_of_files):
        self.timestamp_to_file = {}
        for filename in list_of_files:
            with h5py.File(filename) as d:
                self.timestamp_to_file = {
                    **self.timestamp_to_file,
                    **dict(map(lambda x: (x, filename), d.keys()))
                }

    def __prepare_sequences(self, mode):
        timestamps = np.unique(sorted(self.timestamp_to_file.keys()))
        if mode == 'sequentially':
            self.sequences = [
                timestamps[index * self.seq_len: (index + 1) * self.seq_len]
                for index in range(len(timestamps) // self.seq_len)
            ]
        elif mode == 'overlap':
            self.sequences = [
                timestamps[index: index + self.seq_len]
                for index in range(len(timestamps) - self.seq_len + 1)
            ]
        else:
            raise Exception(f'Unknown mode {mode}')
        self.sequences = list(filter(
            lambda x: int(x[-1]) - int(x[0]) == (self.seq_len - 1) * 600,
            self.sequences
        ))

In [105]:
def prepare_month_loaders_class(train_batch_size=1):
    train_datasets = []
    val_datasets = []
    for i in range(1, 13):
        month = ''
        if i < 10:
            month = '0' + str(i)
        else:
            month = str(i)
        path = PATH + '2021-' + month + '-train.hdf5'
        month_dataset = SeasonDataset([path], season=0) if int(month) < 5 or int(month) > 8 else SeasonDataset([path], season=1)
        train_month, val_month = torch.utils.data.random_split(month_dataset, [0.8, 0.2])
        train_datasets.append(train_month)
        if int(month) >= 5 and int(month) <= 8:
            train_datasets.append(train_month)
        val_datasets.append(val_month)
    full_train_dataset = torch.utils.data.ConcatDataset([train_datasets[q] for q in range(len(train_datasets))])
    train_loader = data.DataLoader(full_train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=4)
    val_loaders = [data.DataLoader(val_datasets[q], batch_size=1, shuffle=False, num_workers=4) for q in range(len(val_datasets))] 
    return train_loader, val_loaders

def evaluate_classification(model, valid_loader):
    correct_0 = 0
    correct_1 = 0
    incorrect_0 = 0
    incorrect_1 = 0
    default = 0
    for item in tqdm.tqdm(valid_loader):
        inputs, target = item
        inputs = inputs[0, :, 0, :, :]
        inputs = inputs.to(torch.device('cuda'))
        output = model(inputs)
        output = output.detach().cpu().numpy()
        if output > 0.6 and target == 1:
            correct_1 += 1
        elif output < 0.4 and target == 1:
            incorrect_1 += 1
        elif output < 0.4 and target == 0:
            correct_0 += 1
        elif output > 0.6 and target == 0:
            incorrect_0 += 1
        else:
            default += 1
            
    return correct_0, incorrect_0, correct_1, incorrect_1, default

In [106]:
class ClassModel(L.LightningModule):
    def __init__(self, num_kernels):
        super().__init__()
        
        self.activation = torch.relu
        self.conv1 = nn.Conv2d(4, num_kernels, kernel_size=3, padding=1) #32 252 252
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) #96 126 126
        self.conv2 = nn.Conv2d(num_kernels, num_kernels, kernel_size=3, padding=1) #32 126 126
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) #96 63 63
        self.conv3 = nn.Conv2d(num_kernels, num_kernels, kernel_size=3, padding=1) #32 63 63
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) #32 32 32
        self.conv4 = nn.Conv2d(num_kernels, num_kernels, kernel_size=3, padding=1) #32 32 32
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2) #32 16 16
        self.conv5 = nn.Conv2d(num_kernels,num_kernels, kernel_size=3, padding=1) # 32 16 16
        self.maxpool5 = nn.MaxPool2d(kernel_size=2, stride=2) #32 8 8
        self.fc1 = nn.Linear(2048, 64)
        self.fc2 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, X):
        output = self.sigmoid(self.fc2(self.activation(self.fc1(self.maxpool5(self.activation(self.conv5(self.maxpool4(self.activation(self.conv4(self.maxpool3(self.activation(self.conv3(
                self.maxpool2(self.activation(self.conv2(self.maxpool1(self.activation(self.conv1(X))))))))))))))).view(-1, 2048)))))
        return output
    
    def training_step(self, batch):
        x, y = batch
        x = x[0,:,0, :, :] #after this x.shape = [8, 252, 252]
        out = self.forward(x)
        loss = F.binary_cross_entropy(out.float(), y.float())
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=5e-4)
        return optimizer
    

In [107]:
class_model = ClassModel(32)

In [108]:
class_train, class_val = prepare_month_loaders_class()

In [None]:
class_model.current_lr = 3e-4
trainer = L.Trainer(
            max_epochs=1
        )
trainer.fit(class_model, class_train)

In [None]:
class_model.current_lr = 5e-4
trainer = L.Trainer(
            max_epochs=1
        )
trainer.fit(class_model, class_train)

In [90]:
device = torch.device('cuda')
class_model.to(device)
for i in range(len(class_val)):
    print(evaluate_classification(class_model, class_val[i]))


  0%|          | 0/55 [00:00<?, ?it/s][A
  2%|▏         | 1/55 [00:00<00:32,  1.66it/s][A
  9%|▉         | 5/55 [00:00<00:08,  6.06it/s][A
 16%|█▋        | 9/55 [00:01<00:05,  8.09it/s][A
 24%|██▎       | 13/55 [00:01<00:04,  9.94it/s][A
 31%|███       | 17/55 [00:01<00:03, 11.00it/s][A
 38%|███▊      | 21/55 [00:02<00:02, 12.47it/s][A
 42%|████▏     | 23/55 [00:02<00:02, 12.89it/s][A
 45%|████▌     | 25/55 [00:02<00:02, 11.83it/s][A
 49%|████▉     | 27/55 [00:02<00:02, 12.00it/s][A
 53%|█████▎    | 29/55 [00:02<00:02, 12.35it/s][A
 56%|█████▋    | 31/55 [00:03<00:02, 11.59it/s][A
 60%|██████    | 33/55 [00:03<00:02, 10.96it/s][A
 64%|██████▎   | 35/55 [00:03<00:01, 11.30it/s][A
 67%|██████▋   | 37/55 [00:03<00:01, 11.84it/s][A
 71%|███████   | 39/55 [00:03<00:01, 12.55it/s][A
 75%|███████▍  | 41/55 [00:03<00:01, 11.11it/s][A
 80%|████████  | 44/55 [00:04<00:00, 13.79it/s][A
 84%|████████▎ | 46/55 [00:04<00:00, 10.99it/s][A
 89%|████████▉ | 49/55 [00:04<00:00,  9.01

(46, 1, 0, 0, 8)



  0%|          | 0/49 [00:00<?, ?it/s][A
  2%|▏         | 1/49 [00:00<00:26,  1.81it/s][A
 10%|█         | 5/49 [00:00<00:06,  7.10it/s][A
 18%|█▊        | 9/49 [00:01<00:04,  9.48it/s][A
 27%|██▋       | 13/49 [00:01<00:03, 11.97it/s][A
 35%|███▍      | 17/49 [00:01<00:02, 12.64it/s][A
 43%|████▎     | 21/49 [00:01<00:02, 12.77it/s][A
 51%|█████     | 25/49 [00:02<00:01, 12.93it/s][A
 59%|█████▉    | 29/49 [00:02<00:01, 12.76it/s][A
 67%|██████▋   | 33/49 [00:02<00:01, 13.08it/s][A
 76%|███████▌  | 37/49 [00:03<00:00, 12.92it/s][A
 84%|████████▎ | 41/49 [00:03<00:00, 12.11it/s][A
 92%|█████████▏| 45/49 [00:03<00:00, 13.84it/s][A
100%|██████████| 49/49 [00:04<00:00, 12.20it/s][A


(33, 1, 0, 0, 15)



  0%|          | 0/49 [00:00<?, ?it/s][A
  2%|▏         | 1/49 [00:00<00:23,  2.00it/s][A
 10%|█         | 5/49 [00:00<00:06,  6.77it/s][A
 18%|█▊        | 9/49 [00:01<00:04,  9.87it/s][A
 27%|██▋       | 13/49 [00:01<00:02, 12.06it/s][A
 35%|███▍      | 17/49 [00:01<00:02, 12.21it/s][A
 43%|████▎     | 21/49 [00:01<00:02, 12.56it/s][A
 51%|█████     | 25/49 [00:02<00:01, 13.23it/s][A
 59%|█████▉    | 29/49 [00:02<00:01, 13.32it/s][A
 67%|██████▋   | 33/49 [00:02<00:01, 13.60it/s][A
 76%|███████▌  | 37/49 [00:03<00:00, 14.21it/s][A
 80%|███████▉  | 39/49 [00:03<00:00, 14.99it/s][A
 84%|████████▎ | 41/49 [00:03<00:00, 13.37it/s][A
 92%|█████████▏| 45/49 [00:03<00:00, 14.11it/s][A
100%|██████████| 49/49 [00:03<00:00, 12.42it/s][A


(37, 0, 0, 0, 12)



  0%|          | 0/52 [00:00<?, ?it/s][A
  2%|▏         | 1/52 [00:00<00:30,  1.68it/s][A
  8%|▊         | 4/52 [00:00<00:06,  7.10it/s][A
 12%|█▏        | 6/52 [00:00<00:06,  7.31it/s][A
 17%|█▋        | 9/52 [00:01<00:05,  8.48it/s][A
 25%|██▌       | 13/52 [00:01<00:03, 10.26it/s][A
 33%|███▎      | 17/52 [00:01<00:03, 11.17it/s][A
 40%|████      | 21/52 [00:02<00:02, 11.60it/s][A
 48%|████▊     | 25/52 [00:02<00:02, 12.70it/s][A
 56%|█████▌    | 29/52 [00:02<00:01, 12.38it/s][A
 63%|██████▎   | 33/52 [00:03<00:01, 12.73it/s][A
 71%|███████   | 37/52 [00:03<00:01, 12.89it/s][A
 79%|███████▉  | 41/52 [00:03<00:00, 12.60it/s][A
 87%|████████▋ | 45/52 [00:03<00:00, 13.37it/s][A
100%|██████████| 52/52 [00:04<00:00, 11.75it/s][A


(24, 11, 0, 0, 17)



  0%|          | 0/55 [00:00<?, ?it/s][A
  2%|▏         | 1/55 [00:00<00:32,  1.69it/s][A
  9%|▉         | 5/55 [00:00<00:07,  6.68it/s][A
 13%|█▎        | 7/55 [00:00<00:05,  8.70it/s][A
 16%|█▋        | 9/55 [00:01<00:05,  9.11it/s][A
 20%|██        | 11/55 [00:01<00:04,  9.71it/s][A
 24%|██▎       | 13/55 [00:01<00:04, 10.13it/s][A
 27%|██▋       | 15/55 [00:01<00:03, 11.07it/s][A
 31%|███       | 17/55 [00:01<00:03, 11.73it/s][A
 35%|███▍      | 19/55 [00:01<00:02, 12.61it/s][A
 38%|███▊      | 21/55 [00:02<00:02, 12.58it/s][A
 42%|████▏     | 23/55 [00:02<00:02, 14.13it/s][A
 45%|████▌     | 25/55 [00:02<00:02, 12.74it/s][A
 49%|████▉     | 27/55 [00:02<00:02, 13.43it/s][A
 53%|█████▎    | 29/55 [00:02<00:02, 12.25it/s][A
 56%|█████▋    | 31/55 [00:02<00:01, 13.27it/s][A
 60%|██████    | 33/55 [00:03<00:01, 11.89it/s][A
 64%|██████▎   | 35/55 [00:03<00:01, 12.76it/s][A
 67%|██████▋   | 37/55 [00:03<00:01, 13.20it/s][A
 71%|███████   | 39/55 [00:03<00:01, 12.87i

(0, 0, 24, 10, 21)



  0%|          | 0/53 [00:00<?, ?it/s][A
  2%|▏         | 1/53 [00:00<00:31,  1.65it/s][A
  9%|▉         | 5/53 [00:00<00:06,  7.31it/s][A
 17%|█▋        | 9/53 [00:01<00:04,  9.44it/s][A
 25%|██▍       | 13/53 [00:01<00:03, 10.89it/s][A
 32%|███▏      | 17/53 [00:01<00:03, 11.28it/s][A
 40%|███▉      | 21/53 [00:02<00:02, 11.67it/s][A
 47%|████▋     | 25/53 [00:02<00:02, 11.17it/s][A
 55%|█████▍    | 29/53 [00:02<00:02, 11.78it/s][A
 62%|██████▏   | 33/53 [00:03<00:01, 12.55it/s][A
 70%|██████▉   | 37/53 [00:03<00:01, 13.34it/s][A
 77%|███████▋  | 41/53 [00:03<00:00, 13.38it/s][A
 85%|████████▍ | 45/53 [00:03<00:00, 13.95it/s][A
 91%|█████████ | 48/53 [00:04<00:00, 15.34it/s][A
 94%|█████████▍| 50/53 [00:04<00:00, 15.75it/s][A
100%|██████████| 53/53 [00:04<00:00, 12.01it/s][A


(0, 0, 24, 7, 22)



  0%|          | 0/53 [00:00<?, ?it/s][A
  2%|▏         | 1/53 [00:00<00:30,  1.73it/s][A
  9%|▉         | 5/53 [00:00<00:06,  7.14it/s][A
 17%|█▋        | 9/53 [00:01<00:04,  9.40it/s][A
 25%|██▍       | 13/53 [00:01<00:03, 10.46it/s][A
 32%|███▏      | 17/53 [00:01<00:03, 11.20it/s][A
 40%|███▉      | 21/53 [00:02<00:02, 12.11it/s][A
 47%|████▋     | 25/53 [00:02<00:02, 12.90it/s][A
 55%|█████▍    | 29/53 [00:02<00:01, 12.83it/s][A
 62%|██████▏   | 33/53 [00:02<00:01, 13.31it/s][A
 70%|██████▉   | 37/53 [00:03<00:01, 14.16it/s][A
 77%|███████▋  | 41/53 [00:03<00:00, 14.09it/s][A
 85%|████████▍ | 45/53 [00:03<00:00, 14.03it/s][A
 92%|█████████▏| 49/53 [00:04<00:00, 14.43it/s][A
100%|██████████| 53/53 [00:04<00:00, 12.31it/s][A


(0, 0, 28, 2, 23)



  0%|          | 0/55 [00:00<?, ?it/s][A
  2%|▏         | 1/55 [00:00<00:28,  1.91it/s][A
  9%|▉         | 5/55 [00:00<00:08,  6.21it/s][A
 16%|█▋        | 9/55 [00:01<00:05,  9.12it/s][A
 24%|██▎       | 13/55 [00:01<00:04,  9.26it/s][A
 31%|███       | 17/55 [00:01<00:03, 10.34it/s][A
 38%|███▊      | 21/55 [00:02<00:03, 10.16it/s][A
 45%|████▌     | 25/55 [00:02<00:02, 11.15it/s][A
 53%|█████▎    | 29/55 [00:02<00:02, 10.95it/s][A
 60%|██████    | 33/55 [00:03<00:01, 11.22it/s][A
 67%|██████▋   | 37/55 [00:03<00:01, 12.63it/s][A
 75%|███████▍  | 41/55 [00:03<00:01, 12.67it/s][A
 82%|████████▏ | 45/55 [00:04<00:00, 13.26it/s][A
 89%|████████▉ | 49/55 [00:04<00:00, 12.60it/s][A
100%|██████████| 55/55 [00:04<00:00, 11.41it/s][A


(0, 0, 41, 2, 12)



  0%|          | 0/53 [00:00<?, ?it/s][A
  2%|▏         | 1/53 [00:00<00:25,  2.08it/s][A
  9%|▉         | 5/53 [00:00<00:06,  7.06it/s][A
 17%|█▋        | 9/53 [00:01<00:04, 10.63it/s][A
 23%|██▎       | 12/53 [00:01<00:02, 13.95it/s][A
 26%|██▋       | 14/53 [00:01<00:03, 12.04it/s][A
 32%|███▏      | 17/53 [00:01<00:02, 12.45it/s][A
 40%|███▉      | 21/53 [00:01<00:02, 11.93it/s][A
 47%|████▋     | 25/53 [00:02<00:02, 12.02it/s][A
 55%|█████▍    | 29/53 [00:02<00:01, 12.54it/s][A
 60%|██████    | 32/53 [00:02<00:01, 14.47it/s][A
 64%|██████▍   | 34/53 [00:02<00:01, 12.10it/s][A
 70%|██████▉   | 37/53 [00:03<00:01,  8.84it/s][A
 77%|███████▋  | 41/53 [00:03<00:01, 10.15it/s][A
 85%|████████▍ | 45/53 [00:04<00:00, 11.05it/s][A
 92%|█████████▏| 49/53 [00:04<00:00, 12.45it/s][A
100%|██████████| 53/53 [00:04<00:00, 11.47it/s][A


(21, 6, 0, 0, 26)



  0%|          | 0/52 [00:00<?, ?it/s][A
  2%|▏         | 1/52 [00:00<00:25,  2.04it/s][A
 10%|▉         | 5/52 [00:00<00:06,  7.75it/s][A
 17%|█▋        | 9/52 [00:01<00:04, 10.19it/s][A
 25%|██▌       | 13/52 [00:01<00:03, 12.33it/s][A
 33%|███▎      | 17/52 [00:01<00:02, 13.56it/s][A
 40%|████      | 21/52 [00:01<00:02, 13.31it/s][A
 48%|████▊     | 25/52 [00:02<00:01, 15.01it/s][A
 56%|█████▌    | 29/52 [00:02<00:01, 14.61it/s][A
 63%|██████▎   | 33/52 [00:02<00:01, 13.36it/s][A
 71%|███████   | 37/52 [00:02<00:01, 13.46it/s][A
 79%|███████▉  | 41/52 [00:03<00:00, 13.33it/s][A
 87%|████████▋ | 45/52 [00:03<00:00, 12.73it/s][A
100%|██████████| 52/52 [00:03<00:00, 13.30it/s][A


(21, 4, 0, 0, 27)



  0%|          | 0/52 [00:00<?, ?it/s][A
  2%|▏         | 1/52 [00:00<00:26,  1.95it/s][A
 10%|▉         | 5/52 [00:00<00:06,  7.43it/s][A
 17%|█▋        | 9/52 [00:01<00:04, 10.33it/s][A
 25%|██▌       | 13/52 [00:01<00:03, 11.34it/s][A
 33%|███▎      | 17/52 [00:01<00:02, 12.76it/s][A
 40%|████      | 21/52 [00:01<00:02, 13.69it/s][A
 44%|████▍     | 23/52 [00:01<00:02, 14.41it/s][A
 48%|████▊     | 25/52 [00:02<00:02, 12.71it/s][A
 56%|█████▌    | 29/52 [00:02<00:01, 13.24it/s][A
 63%|██████▎   | 33/52 [00:02<00:01, 14.28it/s][A
 67%|██████▋   | 35/52 [00:02<00:01, 13.86it/s][A
 71%|███████   | 37/52 [00:03<00:01, 13.39it/s][A
 75%|███████▌  | 39/52 [00:03<00:00, 13.21it/s][A
 79%|███████▉  | 41/52 [00:03<00:00, 12.07it/s][A
 83%|████████▎ | 43/52 [00:03<00:00, 11.51it/s][A
 87%|████████▋ | 45/52 [00:03<00:00, 11.06it/s][A
 90%|█████████ | 47/52 [00:03<00:00, 11.74it/s][A
 94%|█████████▍| 49/52 [00:04<00:00, 12.28it/s][A
100%|██████████| 52/52 [00:04<00:00, 12.12

(40, 0, 0, 0, 12)



  0%|          | 0/55 [00:00<?, ?it/s][A
  2%|▏         | 1/55 [00:00<00:26,  2.04it/s][A
  9%|▉         | 5/55 [00:00<00:06,  7.14it/s][A
 16%|█▋        | 9/55 [00:01<00:04,  9.95it/s][A
 24%|██▎       | 13/55 [00:01<00:03, 11.92it/s][A
 31%|███       | 17/55 [00:01<00:02, 12.98it/s][A
 35%|███▍      | 19/55 [00:01<00:02, 13.75it/s][A
 38%|███▊      | 21/55 [00:01<00:02, 12.49it/s][A
 45%|████▌     | 25/55 [00:02<00:02, 13.75it/s][A
 49%|████▉     | 27/55 [00:02<00:01, 14.54it/s][A
 53%|█████▎    | 29/55 [00:02<00:01, 13.91it/s][A
 56%|█████▋    | 31/55 [00:02<00:01, 14.36it/s][A
 60%|██████    | 33/55 [00:02<00:01, 14.10it/s][A
 64%|██████▎   | 35/55 [00:02<00:01, 14.24it/s][A
 67%|██████▋   | 37/55 [00:03<00:01, 12.92it/s][A
 71%|███████   | 39/55 [00:03<00:01, 14.35it/s][A
 75%|███████▍  | 41/55 [00:03<00:01, 12.26it/s][A
 78%|███████▊  | 43/55 [00:03<00:00, 13.65it/s][A
 82%|████████▏ | 45/55 [00:03<00:00, 11.81it/s][A
 85%|████████▌ | 47/55 [00:03<00:00, 12.75

(44, 3, 0, 0, 8)





In [91]:
torch.save(class_model, 'class_model_7e-4.pt')