In [None]:
import torch

# Training configuration shared by every DataLoader and by the Trainer.
BATCH_SIZE, EPOCHS = 64, 10  # samples per batch; maximum number of training epochs

In [None]:
import os

import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.core import LightningModule, LightningDataModule
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import mobilenetv2

class TrainDataset(Dataset):
    """Labelled samples drawn from a DataFrame.

    Every column except the last is a numeric label; the last column holds a
    96x96 grayscale image encoded as space-separated integer pixel values.
    """

    def __init__(self, df: pd.DataFrame) -> None:
        super().__init__()
        self.df = df
        # ImageNet-style mean/std normalization. It is built but NOT applied in
        # __getitem__. NOTE(review): a pretrained backbone usually expects
        # normalized inputs — confirm that skipping it is intended.
        self.preprocess = transforms.Compose(
            [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
        )

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return ``(image, label)`` for row ``index``.

        image: float tensor of shape (3, 96, 96) with values in [0, 1]; the
            grayscale plane is replicated over three channels.
        label: float32 vector of the label columns divided by 255, with NaN
            entries replaced by -1.
        """
        row = self.df.iloc[index].tolist()
        targets, pixel_str = row[:-1], row[-1]

        gray = torch.tensor(list(map(int, pixel_str.split()))).reshape((96, 96))
        # Replicate the single channel to get the 3-channel input the model takes.
        image = gray.unsqueeze(0).repeat(3, 1, 1) / 255

        label = torch.tensor(targets, dtype=torch.float32).flatten() / 255
        # -1 marks a missing label.
        label = label.masked_fill(torch.isnan(label), -1)
        return image, label
    
class TestDataset(Dataset):
    """Unlabelled images read from a CSV file.

    The last column of every row holds a 96x96 grayscale image encoded as
    space-separated integer pixel values.
    """

    def __init__(self, path: str) -> None:
        super().__init__()
        self.df = pd.read_csv(path)
        # ImageNet-style mean/std normalization. It is built but NOT applied in
        # __getitem__. NOTE(review): keep consistent with TrainDataset.
        self.preprocess = transforms.Compose(
            [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
        )

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return the image for row ``index``.

        The result is a float tensor of shape (3, 96, 96) with values in
        [0, 1]; the grayscale plane is replicated over three channels.
        """
        pixel_str = self.df.iloc[index].tolist()[-1]
        gray = torch.tensor(list(map(int, pixel_str.split()))).reshape((96, 96))
        return gray.unsqueeze(0).repeat(3, 1, 1) / 255

class MyDataModule(LightningDataModule):
    """Lightning data module over ``train.csv`` and ``test.csv`` in one directory.

    The first ``train_size`` rows of ``train.csv`` form the training split and
    the remaining rows the validation split (rows are not shuffled before
    splitting). ``test.csv`` provides unlabelled images for prediction.
    """

    def __init__(self, path: str, train_size: int = 6000) -> None:
        """
        Args:
            path: directory containing ``train.csv`` and ``test.csv``
                (a trailing path separator is optional).
            train_size: number of leading ``train.csv`` rows used for training.
        """
        super().__init__()
        self.path = path
        self.train_size = train_size
        self._datasets_built = False
        # Build eagerly so the dataloaders also work outside a Trainer,
        # e.g. ``dm.test_dataloader()`` passed straight to ``trainer.predict``.
        self.setup()

    def prepare_data(self):
        """No-op: the CSV files are expected to exist locally.

        Dataset construction lives in ``setup``. Lightning runs ``prepare_data``
        on a single process, so state assigned there would not reach other
        processes. It is also re-invoked by ``trainer.fit``, which re-read the
        CSVs for no benefit.
        """

    def setup(self, stage=None):
        """Read the CSVs and build the train/val/test datasets (first call only).

        Raises:
            ValueError: if ``train_size`` would leave either split empty.
        """
        if self._datasets_built:
            return
        df = pd.read_csv(os.path.join(self.path, 'train.csv'))
        if not 0 < self.train_size < len(df):
            raise ValueError(
                f"train_size={self.train_size} must be between 1 and {len(df) - 1} "
                f"so that both splits of train.csv ({len(df)} rows) are non-empty"
            )
        df_train, df_val = df.iloc[:self.train_size], df.iloc[self.train_size:]
        self.train_dataset = TrainDataset(df_train)
        self.val_dataset = TrainDataset(df_val)
        self.test_dataset = TestDataset(os.path.join(self.path, 'test.csv'))
        self._datasets_built = True

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=BATCH_SIZE, shuffle=False)

class MyModel(LightningModule):
    """MobileNetV2 backbone (loaded via torch.hub) with a 30-output linear head.

    Trained with mean squared error. Target entries equal to ``missing_value``
    are excluded from the loss. ``TrainDataset`` writes -1 for NaN labels, so
    by default unlabeled coordinates no longer pull predictions toward -1.
    """

    def __init__(self, lr: float, missing_value=-1.0) -> None:
        """
        Args:
            lr: learning rate for Adam.
            missing_value: target value that marks an unlabeled entry; such
                entries are ignored by the loss. Pass ``None`` to use a plain
                MSE over every entry (the previous behaviour).
        """
        super().__init__()
        # Record lr / missing_value in checkpoints so load_from_checkpoint can rebuild the model.
        self.save_hyperparameters()
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
        # Replace the stock classifier head with a single 30-output regression layer.
        self.model.classifier = torch.nn.Linear(1280, 30)
        self.loss = torch.nn.MSELoss()
        self.lr = lr
        self.missing_value = missing_value
        # Detached per-step training losses of the current epoch.
        self.train_loss = []

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _compute_loss(self, y_hat, y):
        """Mean squared error over the entries of ``y`` that differ from ``missing_value``."""
        if self.missing_value is None:
            return self.loss(y_hat, y)
        mask = (y != self.missing_value).to(y_hat.dtype)
        squared_error = (y_hat - y) ** 2 * mask
        # clamp keeps the result finite (0) when every target in the batch is missing.
        return squared_error.sum() / mask.sum().clamp(min=1.0)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self._compute_loss(self(x), y)
        # Detach: storing the graph-attached tensor would keep every step's
        # autograd graph (and its activations) alive until the epoch ends.
        self.train_loss.append(loss.detach())
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = self._compute_loss(self(x), y)
        self.log("val_loss", loss, prog_bar=True)

    def predict_step(self, batch, batch_idx):
        x = batch
        # Undo the /255 target scaling applied by TrainDataset.
        return self(x) * 255

    def on_train_epoch_end(self):
        # torch.stack raises on an empty list (e.g. an epoch with no steps).
        if not self.train_loss:
            return
        avg_loss = torch.stack(self.train_loss).mean()
        self.log('avg_train_loss', avg_loss, prog_bar=True)
        self.train_loss = []


In [None]:

# class MobileNetV2(torch.nn.Module):
#     def __conv_bn(self, inp, oup, stride):
#         return torch.nn.Sequential(
#             torch.nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
#             torch.nn.BatchNorm2d(oup),
#             torch.nn.ReLU6(inplace=True)
#         )

#     def __conv_1x1_bn(self, inp, oup):
#         return torch.nn.Sequential(
#             torch.nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
#             torch.nn.BatchNorm2d(oup),
#             torch.nn.ReLU6(inplace=True)
#         )

#     def __make_divisible(self, x, divisible_by=8):
#         import numpy as np
#         return int(np.ceil(x * 1. / divisible_by) * divisible_by)

#     class InvertedResidual(torch.nn.Module):
#         def __init__(self, inp, oup, stride, expand_ratio):
#             # super(InvertedResidual, self).__init__()
#             self.stride = stride
#             assert stride in [1, 2]

#             hidden_dim = int(inp * expand_ratio)
#             self.use_res_connect = self.stride == 1 and inp == oup

#             if expand_ratio == 1:
#                 self.conv = torch.nn.Sequential(
#                     # dw
#                     torch.nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
#                     torch.nn.BatchNorm2d(hidden_dim),
#                     torch.nn.ReLU6(inplace=True),
#                     # pw-linear
#                     torch.nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
#                     torch.nn.BatchNorm2d(oup),
#                 )
#             else:
#                 self.conv = torch.nn.Sequential(
#                     # pw
#                     torch.nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
#                     torch.nn.BatchNorm2d(hidden_dim),
#                     torch.nn.ReLU6(inplace=True),
#                     # dw
#                     torch.nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
#                     torch.nn.BatchNorm2d(hidden_dim),
#                     torch.nn.ReLU6(inplace=True),
#                     # pw-linear
#                     torch.nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
#                     torch.nn.BatchNorm2d(oup),
#                 )

#         def forward(self, x):
#             if self.use_res_connect:
#                 return x + self.conv(x)
#             else:
#                 return self.conv(x)


#     def __init__(self, n_class=1000, input_size=224, width_mult=1.):
#         super(MobileNetV2, self).__init__()
#         block = self.InvertedResidual
#         input_channel = 32
#         last_channel = 1280
#         interverted_residual_setting = [
#             # t, c, n, s
#             [1, 16, 1, 1],
#             [6, 24, 2, 2],
#             [6, 32, 3, 2],
#             [6, 64, 4, 2],
#             [6, 96, 3, 1],
#             [6, 160, 3, 2],
#             [6, 320, 1, 1],
#         ]

#         # building first layer
#         assert input_size % 32 == 0
#         # input_channel = make_divisible(input_channel * width_mult)  # first channel is always 32!
#         self.last_channel = self.__make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel
#         self.features = [self.__conv_bn(3, input_channel, 2)]
#         # building inverted residual blocks
#         for t, c, n, s in interverted_residual_setting:
#             output_channel = self.__make_divisible(c * width_mult) if t > 1 else c
#             for i in range(n):
#                 if i == 0:
#                     self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
#                 else:
#                     self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
#                 input_channel = output_channel
#         # building last several layers
#         self.features.append(self.__conv_1x1_bn(input_channel, self.last_channel))
#         # make it nn.Sequential
#         self.features = torch.nn.Sequential(*self.features)

#         # building classifier
#         self.classifier = torch.nn.Linear(self.last_channel, n_class)

#         self._initialize_weights()

#     def forward(self, x):
#         x = self.features(x)
#         x = x.mean(3).mean(2)
#         x = self.classifier(x)
#         return x

#     def _initialize_weights(self):
#         for m in self.modules():
#             if isinstance(m, torch.nn.Conv2d):
#                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
#                 m.weight.data.normal_(0, math.sqrt(2. / n))
#                 if m.bias is not None:
#                     m.bias.data.zero_()
#             elif isinstance(m, torch.nn.BatchNorm2d):
#                 m.weight.data.fill_(1)
#                 m.bias.data.zero_()
#             elif isinstance(m, torch.nn.Linear):
#                 n = m.weight.size(1)
#                 m.weight.data.normal_(0, 0.01)
#                 m.bias.data.zero_()


In [None]:
SEED = 42
LR = 0.0005
# Warm-start weights from an earlier run; set to None to start from the
# pretrained backbone only. NOTE: local path — adjust to your lightning_logs.
CKPT_PATH = './lightning_logs/version_3/checkpoints/epoch=9-step=940.ckpt'

# Seed before building the model and data so head init and shuffling are reproducible.
pl.seed_everything(SEED)
trainer = pl.Trainer(max_epochs=EPOCHS, log_every_n_steps=1, callbacks=[EarlyStopping(monitor="val_loss")])
# load_from_checkpoint is a classmethod that RETURNS a new model; the result must be kept.
model = MyModel.load_from_checkpoint(CKPT_PATH, lr=LR) if CKPT_PATH else MyModel(lr=LR)
dm = MyDataModule(path='./data/300w/')

In [None]:
# Train (and validate each epoch) using the loaders provided by the data module.
trainer.fit(model, datamodule=dm)

In [None]:
# Run inference on the test split. trainer.predict returns one predict_step
# output per batch (already multiplied by 255). Concatenate them into a single
# (num_test_images, 30) tensor rather than dumping the raw list as cell output.
test_predictions = torch.cat(trainer.predict(model, dataloaders=dm.test_dataloader()))
test_predictions.shape