# In [None]:
import torch
import torchvision
from torch import nn
class one2three(nn.Module):
    """torchvision ``mobilenet_v3_small`` classifier adapted to 1-channel images.

    Single-channel input is replicated to three channels, resized with
    bilinear interpolation, and passed through a pretrained
    ``mobilenet_v3_small`` (``pretrained=True``). All backbone parameters
    except those of its ``classifier`` head are frozen, and the last
    classifier layer is replaced by a fresh ``nn.Linear``.

    Args:
        num_classes: Output size of the replaced final linear layer.
            Defaults to 10 (the previously hard-coded value).
        size: ``(height, width)`` the input is resized to before the
            backbone. Defaults to ``(500, 500)`` (the previous value).
        apply_softmax: If True (default, previous behaviour) ``forward``
            returns softmax probabilities over classes; if False it returns
            raw logits, which is what ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self, num_classes=10, size=(500, 500), apply_softmax=True):
        super().__init__()
        self.upscale = nn.Upsample(size=size, mode='bilinear', align_corners=True)
        self.backbone = torchvision.models.mobilenet_v3_small(pretrained=True)

        # Freeze the feature extractor; only the classifier head is trained.
        for name, param in self.backbone.named_parameters():
            if 'classifier' not in name:
                param.requires_grad = False

        # Replace the final layer. Read in_features from the existing layer
        # rather than hard-coding 1024, so the head always matches the backbone.
        in_features = self.backbone.classifier[3].in_features
        self.backbone.classifier[3] = nn.Linear(in_features, num_classes)

        self.apply_softmax = apply_softmax
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch; assumed ``(N, C, H, W)`` since bilinear
                ``nn.Upsample`` needs 4-D input. ``C == 1`` inputs are
                replicated to 3 channels; other channel counts are passed
                through unchanged (the backbone expects 3).

        Returns:
            ``(N, num_classes)`` class probabilities, or logits when the
            module was built with ``apply_softmax=False``.
        """
        if x.size(1) == 1:
            # Zero-copy broadcast of the single channel to RGB. The old
            # torch.cat([x] * 3) also turned 3-channel input into 9 channels.
            x = x.expand(-1, 3, -1, -1)
        x = self.upscale(x)
        x = self.backbone(x)
        return self.softmax(x) if self.apply_softmax else x

import pytorch_lightning as pl
from random import random
class dark(pl.LightningModule):
    """LightningModule that trains ``one2three`` on integer class labels.

    ``one2three()`` ends in a softmax, so ``forward`` returns probabilities.
    The loss is therefore the negative log-likelihood of the clamped
    log-probabilities. Feeding those probabilities to
    ``nn.CrossEntropyLoss`` would apply a second softmax and shrink both
    the loss range and its gradients.

    Args:
        classes: Number of target classes; stored via ``save_hyperparameters``.
            NOTE(review): ``one2three()`` is constructed with its own 10-way
            head, so this value does not currently size the model.
    """

    # Probabilities are clamped to this floor before the log, so a
    # probability that underflows to exactly 0 cannot produce an inf loss.
    _PROB_EPS = 1e-12

    def __init__(self, classes=10):
        super().__init__()
        self.save_hyperparameters()

        self.model = one2three()
        self.loss = nn.NLLLoss()

    def forward(self, x):
        """Return ``one2three``'s class probabilities for batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the NLL loss for one ``(x, y)`` batch and log metrics.

        Every 10th batch also logs two metric groups, ``killer`` and
        ``killer/child0``. Their values are the batch loss plus uniform
        random noise (logging demos, not real accuracy/recall/F1).
        """
        x, y = batch
        probs = self(x)
        log_probs = torch.log(probs.clamp_min(self._PROB_EPS))
        loss = self.loss(log_probs, y)
        # self.log replaces the pre-1.0 Lightning 'log' key in the returned dict.
        self.log('train_loss', loss)

        if batch_idx % 10 == 0:
            l = loss.item()
            self.log("killer", {"acc": l, "recall": l + random(), 'F1': l + random()})
            self.log("killer/child0", {"acc": l, "recall": l + random(), 'F1': l + random()})
        return {'loss': loss}

    def training_epoch_end(self, training_step_outputs):
        """Log the epoch's mean training loss once under ``epoch_test``.

        The ``kk`` entry is the mean loss plus uniform random noise (demo).
        Logging once with an explicit mean replaces re-logging the same key
        once per batch.
        """
        if not training_step_outputs:
            return
        mean_loss = torch.stack([out['loss'] for out in training_step_outputs]).mean()
        self.log('epoch_test', {'loss': mean_loss, 'kk': mean_loss + random()})

    def configure_optimizers(self):
        """Adam with lr=1e-3 over the parameters that are still trainable."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=0.001)

from pytorch_lightning import Trainer, seed_everything
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from pytorch_lightning import loggers as pl_loggers
import os

# Seed before building the model so the freshly initialised classifier
# head (and everything after it) is reproducible.
seed_everything(0)

model = dark()

# data
data_root = os.getcwd()
to_tensor = transforms.ToTensor()
# Training split, reshuffled every epoch.
mnist_train = DataLoader(
    MNIST(data_root, train=True, download=True, transform=to_tensor),
    batch_size=320,
    shuffle=True,
    num_workers=12,
)
# Held-out split. Loading train=True here would make "validation" score
# the exact images the model is trained on.
mnist_val = DataLoader(
    MNIST(data_root, train=False, download=True, transform=to_tensor),
    batch_size=320,
    num_workers=12,
)
# NOTE(review): `dark` defines no validation_step, so this loader is
# presumably skipped (with a warning) until one is added.

tb_logger = pl_loggers.TensorBoardLogger("/aiv-data/tfboard")

# For multi-GPU training use gpus=[0, 1], accelerator='ddp' instead.
trainer = Trainer(gpus=1, logger=[tb_logger], progress_bar_refresh_rate=20, max_epochs=10)
trainer.fit(model, mnist_train, mnist_val)