In [1]:
import os
import numpy as np
import pandas as pd


import matplotlib.pyplot as plt

from pathlib import Path
data_path = Path('/../train/')

In [2]:
# config
config = {
    "data_path": "../",
    "model": {
        "encoder_name": "timm-resnest26d",
        "loss_smooth": 1.0,
        "optimizer_params": {"lr": 0.0035, "weight_decay": 0.0},
        "scheduler": {
            "name": "CosineAnnealingLR",
            "params": {
                "CosineAnnealingLR": {"T_max": 500, "eta_min": 1e-06, "last_epoch": -1},
                "ReduceLROnPlateau": {
                    "factor": 0.35,
                    "mode": "min",
                    "patience": 3,
                    "verbose": True,
                },
            },
        },
        "seg_model": "Unet",
    },
    "output_dir": "models",
    "progress_bar_refresh_rate": 10,
    "seed": 42,
    "train_bs": 128,
    "use_aug": True,
    "trainer": {
        "enable_progress_bar": True,
        "max_epochs": 100,
        "min_epochs": 70,
        "accelerator": "mps",
        "devices": 1,
    },
    "valid_bs": 128,
    "workers": 0,
    "device": "mps",
    "folds": {
        "n_splits": 4,
        "random_state": 42,
        "train_folds": [0, 1, 2, 3]
    }
}


In [3]:
import torch
import numpy as np
import torchvision.transforms as T

import torch
import numpy as np
import torchvision.transforms as T

class ContrailsDataset(torch.utils.data.Dataset):
    def __init__(self, df, image_size=256, train=True, use_augmentations=None):

        self.df = df
        self.trn = train
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.image_size = image_size
        self.use_augmentations = use_augmentations
        if image_size != 256:
            self.resize_image = T.transforms.Resize(image_size)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        con_path = row.path
        con = np.load(str(con_path))

        img = con[..., :-1]
        label = con[..., -1]

        if self.image_size != 256:
            img = self.resize_image(img)

        if self.use_augmentations:
            transformed = self.use_augmentations(image = img, mask = label)
            img = transformed["image"]
            label = transformed["mask"]

        
        label = torch.tensor(label)

        img = torch.tensor(np.reshape(img, (256, 256, 3))).to(torch.float32).permute(2, 0, 1)

        img = self.normalize_image(img)

        return img.float(), label.float()

    def __len__(self):
        return len(self.df)


In [4]:
# Lightning module

import torch
import lightning.pytorch as pl
import segmentation_models_pytorch as smp
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.optim import AdamW
import torch.nn as nn
from lightning.pytorch.callbacks import ProgressBar
from torchmetrics.functional import dice, f1_score, jaccard_index
# from torchmetrics import IoU

bar = ProgressBar()

seg_models = {
    "Unet": smp.Unet,
    "Unet++": smp.UnetPlusPlus,
    "MAnet": smp.MAnet,
    "Linknet": smp.Linknet,
    "FPN": smp.FPN,
    "PSPNet": smp.PSPNet,
    "PAN": smp.PAN,
    "DeepLabV3": smp.DeepLabV3,
    "DeepLabV3+": smp.DeepLabV3Plus,
}

class DiceLoss(nn.Module):
    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        logits = torch.sigmoid(logits)

        # flatten label and prediction tensors
        logits = logits.view(-1)
        targets = targets.view(-1)

        intersection = (logits * targets).sum()
        return 1 - (2.0 * intersection + self.smooth) / (
            logits.sum() + targets.sum() + self.smooth
        )


class LightningModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
# Instantiate the base model
        self.model = seg_models[config["seg_model"]](
            encoder_name=config["encoder_name"],
            encoder_weights=None,  # None here because we'll load our own weights
            in_channels=3,
            classes=1,  # Number of classes in your pre-trained model
            activation=None,
        )

        # Load the pre-trained weights into the model, excluding the final layer
        state_dict = torch.load('models/Unet_timm-resnest26d.ckpt')
        # del state_dict['model.classifier.weight']  # Adjust these keys based on your model's structure
        # del state_dict['model.classifier.bias']
        self.model.load_state_dict(state_dict, strict=False)  

        # Replace the final layer to match the number of classes for the new task
        # num_ftrs = self.model.classifier.in_features
        # self.model.classifier = nn.Linear(num_ftrs, 1)

        # Send the model to device
        # self.model = self.model.to(self.device)


        self.loss_module = DiceLoss(smooth=config["loss_smooth"])
        self.val_step_outputs = []
        self.val_step_labels = []

    def forward(self, batch):
        imgs = batch.to(self.device)
        return self.model(imgs)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), **self.config["optimizer_params"])

        if self.config["scheduler"]["name"] == "CosineAnnealingLR":
            scheduler = CosineAnnealingLR(
                optimizer,
                **self.config["scheduler"]["params"][self.config["scheduler"]["name"]],
            )
            lr_scheduler_dict = {"scheduler": scheduler, "interval": "step"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}
        elif self.config["scheduler"]["name"] == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                **self.config["scheduler"]["params"][self.config["scheduler"]["name"]],
            )
            lr_scheduler = {"scheduler": scheduler, "monitor": "val_loss"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        labels = labels.unsqueeze(1)
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=16)

        for param_group in self.trainer.optimizers[0].param_groups:
            lr = param_group["lr"]
        self.log("lr", lr, on_step=True, on_epoch=False, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        labels = labels.unsqueeze(1)
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        # self.val_step_outputs.append(preds)
        # self.val_step_labels.append(labels)

    # def on_validation_epoch_end(self):
    #     all_preds = torch.cat(self.val_step_outputs)
    #     all_labels = torch.cat(self.val_step_labels)
    #     self.val_step_outputs.clear()
    #     self.val_step_labels.clear()
    #     # val_dice = dice(all_preds, all_labels.long())
    #     # val_f1 = f1_score(all_preds.sigmoid(), all_labels.long(), task = "binary")
    #     val_iou = jaccard_index(num_classes=2, task='binary', preds=all_preds.sigmoid(), target=all_labels.long())
    #     self.log("val_iou", val_iou, on_step=False, on_epoch=True, prog_bar=True)

In [5]:
import warnings

warnings.filterwarnings("ignore")

import os
import torch
import pandas as pd
import lightning.pytorch as pl
from pprint import pprint
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
from torch.utils.data import DataLoader
import albumentations as A

contrails = os.path.join(config["data_path"], "contrails/")
train_path = os.path.join(config["data_path"], "train_df.csv")
valid_path = os.path.join(config["data_path"], "valid_df.csv")

train_df = pd.read_csv(train_path)
valid_df = pd.read_csv(valid_path)

train_df["path"] = contrails + train_df["record_id"].astype(str) + ".npy"
valid_df["path"] = contrails + valid_df["record_id"].astype(str) + ".npy"

if config["use_aug"] == True:
    transform_set = A.Compose([A.VerticalFlip(p=0.5),
                            A.HorizontalFlip(p=0.5),
                            A.RandomRotate90(p=0.5)])
else:
    transform_set = None

dataset_train = ContrailsDataset(train_df, train=True, use_augmentations=transform_set)
dataset_validation = ContrailsDataset(valid_df, train=False)

data_loader_train = DataLoader(
    dataset_train, batch_size=config["train_bs"], shuffle=True, num_workers=config["workers"]
)
data_loader_validation = DataLoader(
    dataset_validation, batch_size=config["valid_bs"], shuffle=False, num_workers=config["workers"]
)

pl.seed_everything(config["seed"])

filename = f"{config['model']['seg_model']}_{config['model']['encoder_name']}_contrails"

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=config["output_dir"],
    mode="min",
    filename=filename,
    save_top_k=1,
    verbose=1,
)

progress_bar_callback = TQDMProgressBar(refresh_rate=config["progress_bar_refresh_rate"])

early_stop_callback = EarlyStopping(monitor="val_loss", mode="min", patience=5, verbose=1)

trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stop_callback, progress_bar_callback], logger=None, **config["trainer"]
)

model = LightningModule(config["model"])

trainer.fit(model, data_loader_train, data_loader_validation)

Global seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name        | Type     | Params
-----------------------------------------
0 | model       | Unet     | 24.0 M
1 | loss_module | DiceLoss | 0     
-----------------------------------------
24.0 M    Trainable params
0         Non-trainable params
24.0 M    Total params
96.134    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 0.748
Epoch 0, global step 161: 'val_loss' reached 0.74820 (best 0.74820), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.153 >= min_delta = 0.0. New best score: 0.595
Epoch 1, global step 322: 'val_loss' reached 0.59518 (best 0.59518), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.028 >= min_delta = 0.0. New best score: 0.567
Epoch 2, global step 483: 'val_loss' reached 0.56675 (best 0.56675), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.564
Epoch 3, global step 644: 'val_loss' reached 0.56426 (best 0.56426), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 4, global step 805: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 5, global step 966: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 6, global step 1127: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.070 >= min_delta = 0.0. New best score: 0.494
Epoch 7, global step 1288: 'val_loss' reached 0.49438 (best 0.49438), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 0.483
Epoch 8, global step 1449: 'val_loss' reached 0.48285 (best 0.48285), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.474
Epoch 9, global step 1610: 'val_loss' reached 0.47355 (best 0.47355), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 10, global step 1771: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 11, global step 1932: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 12, global step 2093: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 13, global step 2254: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 0.463
Epoch 14, global step 2415: 'val_loss' reached 0.46290 (best 0.46290), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.456
Epoch 15, global step 2576: 'val_loss' reached 0.45622 (best 0.45622), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 16, global step 2737: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 17, global step 2898: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 18, global step 3059: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 19, global step 3220: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 0.456. Signaling Trainer to stop.
Epoch 20, global step 3381: 'val_loss' was not in top 1
Trainer was signaled to stop but the required `min_epochs=70` or `min_steps=None` has not been met. Training will continue...


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.449
Epoch 21, global step 3542: 'val_loss' reached 0.44869 (best 0.44869), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.448
Epoch 22, global step 3703: 'val_loss' reached 0.44808 (best 0.44808), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 23, global step 3864: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 24, global step 4025: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 25, global step 4186: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.445
Epoch 26, global step 4347: 'val_loss' reached 0.44477 (best 0.44477), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.438
Epoch 27, global step 4508: 'val_loss' reached 0.43787 (best 0.43787), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 28, global step 4669: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 29, global step 4830: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 30, global step 4991: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 31, global step 5152: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.434
Epoch 32, global step 5313: 'val_loss' reached 0.43385 (best 0.43385), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.430
Epoch 33, global step 5474: 'val_loss' reached 0.43018 (best 0.43018), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.425
Epoch 34, global step 5635: 'val_loss' reached 0.42543 (best 0.42543), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 35, global step 5796: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 36, global step 5957: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 37, global step 6118: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 38, global step 6279: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.423
Epoch 39, global step 6440: 'val_loss' reached 0.42285 (best 0.42285), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 40, global step 6601: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 41, global step 6762: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 42, global step 6923: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 43, global step 7084: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 0.423. Signaling Trainer to stop.
Epoch 44, global step 7245: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 6 records. Best score: 0.423. Signaling Trainer to stop.
Epoch 45, global step 7406: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.422
Epoch 46, global step 7567: 'val_loss' reached 0.42171 (best 0.42171), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 47, global step 7728: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 48, global step 7889: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 49, global step 8050: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 50, global step 8211: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 0.422. Signaling Trainer to stop.
Epoch 51, global step 8372: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.420
Epoch 52, global step 8533: 'val_loss' reached 0.42012 (best 0.42012), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 53, global step 8694: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 54, global step 8855: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 55, global step 9016: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 56, global step 9177: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 0.420. Signaling Trainer to stop.
Epoch 57, global step 9338: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.412
Epoch 58, global step 9499: 'val_loss' reached 0.41166 (best 0.41166), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 59, global step 9660: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 60, global step 9821: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 61, global step 9982: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 62, global step 10143: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 0.412. Signaling Trainer to stop.
Epoch 63, global step 10304: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 6 records. Best score: 0.412. Signaling Trainer to stop.
Epoch 64, global step 10465: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 7 records. Best score: 0.412. Signaling Trainer to stop.
Epoch 65, global step 10626: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.409
Epoch 66, global step 10787: 'val_loss' reached 0.40904 (best 0.40904), saving model to '/Users/johnny/Library/CloudStorage/OneDrive-Personal/py/Kaggle/contrails/notebooks/models/Unet_timm-resnest26d_contrails-v5.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 67, global step 10948: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 68, global step 11109: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 69, global step 11270: 'val_loss' was not in top 1


In [6]:
# save the model
torch.save(model.state_dict(), f"models/{config['model']['seg_model']}_{config['model']['encoder_name']}_contrails_70_aug.pt")

In [6]:
%reload_ext tensorboard

In [8]:
%tensorboard --logdir=lightning_logs/

Reusing TensorBoard on port 6006 (pid 4011), started 0:00:03 ago. (Use '!kill 4011' to kill it.)

In [None]:
data_df = pd.concat([train_df, valid_df])
data_df = data_df.reset_index(drop=True)
data_df.to_csv("../data_df.csv", index=False)