In [None]:
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from src.data.functions import fetch_and_load_covtype_dataset
from src.models.deep_linear import DeepLinear
from src.models.lightning_wrapper import LightningWrapper
from src.models.simple_logistic_surrogate import SimpleLogisticSurrogate
from torch.utils.data import DataLoader

# Loading data

In [None]:
train_set, test_set = fetch_and_load_covtype_dataset()

# Only the training data needs shuffling (for SGD). The test loader keeps a
# fixed order: aggregate test metrics don't depend on batch order, and a
# fixed order keeps per-batch outputs reproducible across runs.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)

# Training the root model

In [None]:
# EarlyStopping monitors "valid_loss", which only exists if a validation loop
# runs. Without `val_dataloaders`, the metric is never produced and
# EarlyStopping (strict by default) raises at the end of the first epoch.
# Hold out part of the *training* data for validation; the test set is not
# used, so the stopping point is not tuned on test data.
# NOTE(review): assumes LightningWrapper logs "valid_loss" in its
# validation_step — confirm against src/models/lightning_wrapper.py.
SEED = 42
VALID_FRACTION = 0.1
BATCH_SIZE = 128
MAX_EPOCHS = 100

pl.seed_everything(SEED)

n_valid = max(1, int(len(train_set) * VALID_FRACTION))
fit_subset, valid_subset = torch.utils.data.random_split(
    train_set,
    [len(train_set) - n_valid, n_valid],
    generator=torch.Generator().manual_seed(SEED),  # reproducible split
)
fit_loader = DataLoader(fit_subset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_subset, batch_size=BATCH_SIZE, shuffle=False)

model = DeepLinear(train_set.n_features, train_set.n_classes)

loss_function = nn.CrossEntropyLoss()
network = LightningWrapper(model, loss_function)

early_stopping = EarlyStopping(monitor="valid_loss", mode="min")
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, accelerator='cpu', callbacks=[early_stopping])
trainer.fit(network, train_dataloaders=fit_loader, val_dataloaders=valid_loader)

In [None]:
# Build the surrogate from the same `n_features` / `n_classes` arguments that
# were used to construct the root `DeepLinear` model.
# NOTE(review): presumably this surrogate is later fit to approximate the root
# model's predictions — confirm against the cells that follow.
surrogate_model = SimpleLogisticSurrogate(train_set.n_features, train_set.n_classes)