In [0]:
import configparser
import pathlib as path
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from idao.data_module import IDAODataModule
from idao.model import SimpleConv
import numpy as np
import torch
import os
import pathlib as path
from PIL import Image
from torchvision.datasets import DatasetFolder
from torch.utils.data import Dataset
import pathlib as path

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F

import configparser
import pathlib as path

import pytorch_lightning as pl
from pytorch_lightning import seed_everything



In [0]:
class IDAODataset(DatasetFolder):
    """Folder dataset whose items also carry metadata parsed from the file name.

    Each item is ``(sample, target, energy, index)`` where ``energy`` and
    ``index`` are derived from the sample's file name.
    """

    def name_to_energy(self, name):
        """Return the energy encoded in ``name`` as a scalar float tensor.

        The file name is split on ``_``; the value is the token that sits
        immediately before the first ``keV`` token.
        """
        tokens = os.path.basename(name).split("_")
        kev_positions = []
        for position, token in enumerate(tokens):
            if token == "keV":
                kev_positions.append(position)
        first_kev = kev_positions[0]
        return torch.tensor(float(tokens[first_kev - 1]))

    def name_to_index(self, name):
        """Return the part of the file name that precedes its first ``.``."""
        file_name = os.path.basename(name)
        return file_name.partition(".")[0]

    def __getitem__(self, index: int):
        """Load item ``index`` and return ``(sample, target, energy, index_str)``."""
        file_path, label = self.samples[index]
        image = self.loader(file_path)

        # Both transforms are optional.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        energy = self.name_to_energy(file_path)
        sample_id = self.name_to_index(file_path)
        return image, label, energy, sample_id

class InferenceDataset(Dataset):
    """Unlabelled dataset over the ``*.png`` files directly inside a directory.

    Each item is ``(transformed_image, file_name)``.

    Args:
        main_dir: directory that is scanned (non-recursively) for ``*.png`` files.
        transform: callable applied to every loaded image; when ``None`` the
            loaded image is returned unchanged.
        loader: callable that turns a file path into an image object. When
            omitted, the module-level ``img_loader`` is used.
    """

    def __init__(self, main_dir, transform, loader=None):
        self.main_dir = path.Path(main_dir)
        self.transform = transform
        # glob() yields entries in a filesystem-dependent order; sorting makes
        # the item order (and therefore prediction order) reproducible.
        self.all_imgs = sorted(self.main_dir.glob("*.png"))
        # Fall back to the default loader so that ``loader=None`` is usable
        # instead of failing later with "'NoneType' object is not callable".
        self.loader = img_loader if loader is None else loader

    def __len__(self):
        """Number of ``*.png`` files found in ``main_dir``."""
        return len(self.all_imgs)

    def __getitem__(self, idx):
        """Return ``(image, file_name)`` for the ``idx``-th file."""
        img_loc = self.all_imgs[idx]
        image = self.loader(img_loc)
        if self.transform is not None:
            image = self.transform(image)
        return image, img_loc.name

def img_loader(path: str):
    """Open the image at ``path`` and return its contents as a NumPy array.

    The image file is closed before the function returns.
    """
    with Image.open(path) as opened:
        return np.array(opened)


In [0]:
class IDAODataModule(pl.LightningDataModule):
    """Lightning data module for the labelled train set and unlabelled test set.

    Args:
        data_dir: dataset root; its ``train`` and ``test`` sub-directories are read.
        batch_size: batch size of the train and test loaders (the validation
            loader always uses a batch size of 1).
        cfg: mapping-like configuration; ``cfg["DATA"]["Extension"]`` selects
            which training files are picked up.
    """

    CROP_SIZE = 120     # side length of the centre crop applied to every image
    TRAIN_SIZE = 10000  # number of samples placed in the training split
    SPLIT_SEED = 666    # seed of the generator driving the train/val split

    def __init__(self, data_dir: path.Path, batch_size: int, cfg):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.cfg = cfg
        # Filled in by prepare_data() / setup().
        self.dataset = None
        self.test = None

    @staticmethod
    def _one_hot_target(num):
        """Map a class index to a 2-element tensor: 0 -> [0, 1], else -> [1, 0].

        A named function (rather than a lambda) so the dataset stays picklable
        for DataLoader worker processes.
        """
        # TODO(kazeevn) use idiomatic torch
        return torch.tensor([0, 1]) if num == 0 else torch.tensor([1, 0])

    def _image_transform(self):
        """Transform shared by train and test images: to-tensor, then centre crop."""
        return transforms.Compose(
            [transforms.ToTensor(), transforms.CenterCrop(self.CROP_SIZE)]
        )

    def _build_datasets(self):
        """Create the labelled (``self.dataset``) and inference (``self.test``) datasets."""
        self.dataset = IDAODataset(
            root=self.data_dir.joinpath("train"),
            loader=img_loader,
            transform=self._image_transform(),
            target_transform=self._one_hot_target,
            extensions=self.cfg["DATA"]["Extension"],
        )
        self.test = InferenceDataset(
            main_dir=self.data_dir.joinpath("test"),
            loader=img_loader,
            transform=self._image_transform(),
        )

    def prepare_data(self):
        """Build the datasets (kept for callers that invoke it directly).

        Lightning runs this hook on a single process only, so ``setup`` must
        not rely on state assigned here.
        """
        self._build_datasets()

    def setup(self, stage=None):
        """Split the labelled data into train/validation subsets.

        Called on every process, so the datasets are (re)built here when
        ``prepare_data`` did not run in this process.

        Raises:
            ValueError: if the labelled dataset has no samples left for
                validation after taking ``TRAIN_SIZE`` training samples.
        """
        if self.dataset is None or self.test is None:
            self._build_datasets()

        total = len(self.dataset)
        val_size = total - self.TRAIN_SIZE
        if val_size <= 0:
            raise ValueError(
                f"Dataset has {total} samples; more than {self.TRAIN_SIZE} "
                "are required to form a train/validation split"
            )

        # For the 13404-sample dataset this is the original [10000, 3404] split.
        self.train, self.val = random_split(
            self.dataset,
            [self.TRAIN_SIZE, val_size],
            generator=torch.Generator().manual_seed(self.SPLIT_SEED),
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train, self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        """Unshuffled, batch-size-1 loader over the validation split."""
        return DataLoader(self.val, 1, num_workers=3, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the unlabelled test images."""
        return DataLoader(
            self.test,
            self.batch_size,
            num_workers=0,
            shuffle=False,
        )



In [0]:
class Print(nn.Module):
    """Debugging only: identity layer that prints the size of whatever passes through."""

    def forward(self, x):
        """Print ``x.size()`` and hand ``x`` back untouched."""
        dims = x.size()
        print(dims)
        return x


class Clamp(nn.Module):
    """Clamp energy output to the closed range [0, 30]."""

    def forward(self, x):
        """Return ``x`` with every element limited to ``0 <= value <= 30``."""
        return x.clamp(min=0, max=30)


class SimpleConv(pl.LightningModule):
    """Small two-block CNN with either a 2-logit classification head or a
    1-value regression head, selected by ``mode``.

    Any ``mode`` other than ``"classification"`` selects the regression path.
    Batches for the training/validation steps are 4-tuples
    ``(images, class_targets, energy_targets, _)``; the last element is ignored.
    """

    def __init__(self, mode: ["classification", "regression"] = "classification"):
        super().__init__()
        self.mode = mode

        # Convolutional feature extractor for single-channel images. For a
        # 1x120x120 input the flattened output has 32 * 5 * 5 = 800 features,
        # which is what ``fc1`` expects.
        feature_layers = [
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=6),
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=6, stride=3),
            nn.Flatten(),
        ]
        self.layer1 = nn.Sequential(*feature_layers)

        self.drop_out = nn.Dropout(p=0.4)

        self.fc1 = nn.Linear(800, 200)
        self.fc2 = nn.Linear(200, 2)  # classification head
        self.fc3 = nn.Linear(200, 1)  # regression head

        # Shared trunk: features -> dropout -> fc1.
        self.stem = nn.Sequential(self.layer1, self.drop_out, self.fc1)

        # Only the pipeline matching ``mode`` is assembled.
        if self.mode != "classification":
            self.regression = nn.Sequential(self.stem, self.fc3)
        else:
            self.classification = nn.Sequential(self.stem, self.fc2)

        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        images, labels, energies, _ = batch
        inputs = images.float()

        if self.mode != "classification":
            # Regression: mean squared error against the energy as a column vector.
            predicted_energy = self.regression(inputs)
            loss = F.mse_loss(predicted_energy, energies.float().view(-1, 1))
            self.log("regression_loss", loss)
            return loss

        # Classification: BCE on raw logits, accuracy on sigmoid outputs.
        logits = self.classification(inputs)
        loss = F.binary_cross_entropy_with_logits(logits, labels.float())
        self.train_acc(torch.sigmoid(logits), labels)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
        self.log("classification_loss", loss)
        return loss

    def training_epoch_end(self, outs):
        """Log the epoch-level training accuracy (classification mode only)."""
        if self.mode == "classification":
            self.log("train_acc_epoch", self.train_acc.compute())

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss for one batch."""
        images, labels, energies, _ = batch
        inputs = images.float()

        if self.mode != "classification":
            predicted_energy = self.regression(inputs)
            loss = F.mse_loss(predicted_energy, energies.float().view(-1, 1))
            self.log("regression_loss", loss)
            return loss

        logits = self.classification(inputs)
        loss = F.binary_cross_entropy_with_logits(logits, labels.float())
        self.valid_acc(torch.sigmoid(logits), labels)
        self.log("valid_acc", self.valid_acc.compute())
        self.log("classification_loss", loss)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with Adamax at a learning rate of 1e-3."""
        return torch.optim.Adamax(self.parameters(), lr=1e-3)

    def forward(self, x):
        """Run inference.

        Returns ``{"class": sigmoid(logits)}`` in classification mode and
        ``{"energy": prediction}`` otherwise.
        """
        inputs = x.float()
        if self.mode == "classification":
            return {"class": torch.sigmoid(self.classification(inputs))}
        return {"energy": self.regression(inputs)}


In [0]:
# Instantiate the network with its default mode ("classification") for inspection below.
DM_model = SimpleConv()

In [0]:
# Show a layer-by-layer summary of the model for a single one-channel 120x120 input.
from torchinfo import summary
# input_size includes the batch dimension: (batch, channels, height, width).
summary(DM_model, input_size=(1,1,120,120))

In [0]:
def trainer(mode: ["classification", "regression"], cfg, dataset_dm):
    """Create a fresh ``SimpleConv`` in the given ``mode`` and fit it on ``dataset_dm``.

    Reads from ``cfg["TRAINING"]``: ``ClassificationEpochs`` or
    ``RegressionEpochs`` (depending on ``mode``), ``NumGPUs`` and
    ``ModelParamsSavePath``. Weights are saved under
    ``<ModelParamsSavePath>/<mode>``. Returns nothing.
    """
    model = SimpleConv(mode=mode)

    training_cfg = cfg["TRAINING"]
    # Any mode other than "classification" uses the regression epoch count.
    epochs_key = (
        "ClassificationEpochs" if mode == "classification" else "RegressionEpochs"
    )
    epochs = training_cfg[epochs_key]

    trainer_options = dict(
        gpus=int(training_cfg["NumGPUs"]),
        max_epochs=int(epochs),
        progress_bar_refresh_rate=20,
        weights_save_path=path.Path(training_cfg["ModelParamsSavePath"]) / mode,
        default_root_dir=path.Path(training_cfg["ModelParamsSavePath"]),
    )
    lightning_trainer = pl.Trainer(**trainer_options)

    # Train the model ⚡
    lightning_trainer.fit(model, dataset_dm)


def main(config_path: str = "./config.ini"):
    """Train the classification model and then the regression model.

    Args:
        config_path: INI file providing ``DATA.DatasetPath`` and
            ``TRAINING.BatchSize`` (plus whatever ``trainer`` and the data
            module read from it). Defaults to ``./config.ini``.

    Raises:
        FileNotFoundError: if ``config_path`` could not be read.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it parsed; fail loudly instead of hitting a confusing
    # KeyError on the first section lookup.
    if not config.read(config_path):
        raise FileNotFoundError(f"Could not read configuration file: {config_path}")

    seed_everything(666)

    data_dir = path.Path(config["DATA"]["DatasetPath"])

    dataset_dm = IDAODataModule(
        data_dir=data_dir,
        batch_size=int(config["TRAINING"]["BatchSize"]),
        cfg=config,
    )
    dataset_dm.prepare_data()
    dataset_dm.setup()

    # One independent model is trained per mode on the same data module.
    for mode in ["classification", "regression"]:
        print(f"Training for {mode}")
        trainer(mode, cfg=config, dataset_dm=dataset_dm)


In [0]:
# Run the full training pipeline only when this code is executed as the top-level module.
if __name__ == "__main__":
    main()