In [8]:
import os 
import torch 
from torch import nn 
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split 
from torchvision import transforms, models
import pytorch_lightning as pl 
from utils import PyTorchSatellitePoseEstimationDataset

In [9]:
# Root of the SPEED dataset. Override with the SPEED_ROOT environment variable so the
# notebook runs on other machines; falls back to the original author's local path.
ROOT_DIR = os.environ.get('SPEED_ROOT', '/home/salem/Documents/DLR/challenge/speed')

In [14]:
class SatellitePoseEstimationModel(pl.LightningModule):
    """ResNet-18 regressor that predicts a 7-value pose vector per image.

    The ImageNet-pretrained ResNet-18 backbone is kept and only its final fully
    connected layer is replaced by a ``Linear(num_ftrs, 7)`` head; training
    minimizes the MSE between the 7 outputs and the targets.

    NOTE(review): the 7 outputs presumably are quaternion (4) + translation (3)
    in whatever order ``PyTorchSatellitePoseEstimationDataset`` emits the
    target — confirm against that dataset.

    Args:
        lr: Learning rate for the Adam optimizer (default 1e-3, unchanged).
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        # Record ``lr`` in checkpoints / logger hparams.
        self.save_hyperparameters()
        self.lr = lr
        # NOTE(review): ``pretrained=True`` is deprecated in torchvision >= 0.13
        # in favour of ``weights=...``; kept for compatibility with the installed version.
        initialized_model = models.resnet18(pretrained=True)
        num_ftrs = initialized_model.fc.in_features
        initialized_model.fc = nn.Linear(num_ftrs, 7)
        self.model = initialized_model

    def forward(self, x):
        """Return the raw 7-value regression output for a batch of images."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute the MSE loss between predictions and targets for one ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        # Cast both sides to float32 so the loss works whatever dtype the targets arrive in.
        return F.mse_loss(y_hat.float(), y.float())

    def training_step(self, batch, batch_idx):
        """Log and return the training MSE loss."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation MSE loss.

        Without this hook Lightning never runs the datamodule's
        ``val_dataloader``, so the held-out split would go unused.
        """
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    

In [15]:
class DataModule(pl.LightningDataModule):
    """LightningDataModule for the SPEED satellite pose dataset.

    The ``'train'`` split of ``PyTorchSatellitePoseEstimationDataset`` is
    randomly divided into training and validation subsets; the ``'test'``
    split is loaded separately for the test stage.

    Args:
        batch_size: Batch size for every dataloader.
        speed_root: Root directory of the SPEED dataset.
        val_fraction: Fraction of the ``'train'`` split held out for validation
            (default 0.2, as before).
        num_workers: Worker processes used by the dataloaders (default 8).
        seed: Seed for the train/val split so the same images land in the
            validation set on every run.
    """

    def __init__(self, batch_size, speed_root, val_fraction=0.2, num_workers=8, seed=42):
        super().__init__()
        self.batch_size = batch_size
        self.speed_root = speed_root
        self.val_fraction = val_fraction
        self.num_workers = num_workers
        self.seed = seed

    @staticmethod
    def _transforms():
        """Resize to 224x224, convert to a tensor and apply ImageNet mean/std normalization."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` (``None`` builds all of them)."""
        data_transforms = self._transforms()
        if stage in (None, "fit", "validate"):
            full_dataset = PyTorchSatellitePoseEstimationDataset('train', self.speed_root, data_transforms)
            n_total = len(full_dataset)
            n_val = int(n_total * self.val_fraction)
            # Take the remainder for training so the lengths always sum to n_total;
            # int(n * .8) + int(n * .2) can fall short (e.g. n=7 -> 5 + 1) and make
            # random_split raise.
            n_train = n_total - n_val
            self.train_dataset, self.val_dataset = random_split(
                full_dataset,
                [n_train, n_val],
                generator=torch.Generator().manual_seed(self.seed),
            )
        # Was ``stage == (None, "test")``, which is never true, so the test
        # dataset was never created.
        if stage in (None, "test"):
            self.test_dataset = PyTorchSatellitePoseEstimationDataset('test', self.speed_root, data_transforms)

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Loader over the validation subset; not shuffled, since order does not affect the metric."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        """Loader over the ``'test'`` split in dataset order."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

In [16]:
# Seed the global RNGs through Lightning so weight init of the new head and
# shuffling are reproducible across runs.
pl.seed_everything(42)
model = SatellitePoseEstimationModel()
dm = DataModule(batch_size=32, speed_root=ROOT_DIR)
# Use a GPU when one is present; on this machine none is ("GPU available: False"),
# so this still trains on CPU as before.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=10)
trainer.fit(model, dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.720    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [None]:
!tensorboard --logdir lightning_logs

2021-08-18 18:28:38.816116: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.4.1 at http://localhost:6006/ (Press CTRL+C to quit)
