In [1]:
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl

# from scint_dataset import ScintillationDataset

In [12]:

import torch
from torch.utils.data import Dataset

import warnings
import numpy
from numpy import array, concatenate, arange, unwrap, angle
from scipy.constants import c, pi
from compact_simulator import simulate_scintillation
# Carrier frequencies in Hz fed to the simulator (the values look like the
# GPS L1 / L2 / L5 carriers -- TODO confirm).
fL1, fL2, fL5 = 1.57542e9, 1.2276e9, 1.17645e9


class ScintillationDataset(Dataset):
    '''
    Map-style dataset whose items are generated on demand by
    ``simulate_scintillation``; nothing is stored between calls.

    Parameters
    ----------
    length : int
        Value reported by ``len()``.
    set_index : int
        Index of the split this dataset represents. Item ``idx`` passes
        ``set_index * length + idx`` as the last simulator argument
        (presumably a random seed -- TODO confirm), so splits built with the
        same ``length`` and different ``set_index`` request disjoint values.

    Notes
    -----
    ``__getitem__`` does not bounds-check ``idx``; indices ``>= length`` are
    simply forwarded into the value described above.
    '''

    def __init__(self, length, set_index):
        self.length = length
        self.set_index = set_index

    def __len__(self):
        '''Return the nominal number of items in this split.'''
        return self.length

    def __getitem__(self, idx):
        '''
        Simulate one realisation and return it as an ``(x, y)`` pair.

        Returns
        -------
        x : torch.Tensor (float32)
            Unwrapped phases of ``psi`` for each frequency stacked on top of
            the amplitudes of ``psi`` (2 * Nf rows).
        y : torch.Tensor (float32)
            Scaled phase screens averaged over the frequency axis (named TEC
            in the code).
        '''
        # Fixed simulation parameters, passed straight through to the simulator.
        U, mu0, p1, p2, roveff = 1.2, 1, 2.6, 3.2, 0.9
        freqs = array([fL1, fL2, fL5])
        Nt = 2**12
        Nf = len(freqs)
        dt = 0.01
        seed = self.set_index * self.length + idx

        # Keep item generation quiet, but restore the caller's warning filters
        # afterwards instead of silencing every warning in the process for good.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            psi, phase_screens = simulate_scintillation(U, mu0, p1, p2, roveff, freqs, Nt, dt, seed)
            TEC = numpy.mean(phase_screens[:, :] * c * freqs[:, None] / 40.308e16 / (2 * pi), axis=0)

            # Row-wise phase unwrapping and magnitude for the first Nf rows of psi.
            phases = unwrap(angle(psi[:Nf]), axis=-1)
            amplitudes = abs(psi[:Nf])

            x = concatenate((phases, amplitudes), axis=0)
            return (torch.as_tensor(x, dtype=torch.float32),
                    torch.as_tensor(TEC, dtype=torch.float32))

In [13]:
# Small dataset from split 0, used below to pull a sample for inspection.
dataset = ScintillationDataset(length=100, set_index=0)

In [15]:
# Index a sample inside a dataset's bounds instead of reading item 20001 of a
# 100-item dataset. Item 1 of split 2 with length 10000 requests the same
# simulator value as before (2 * 10000 + 1 == 0 * 100 + 20001 == 20001), which
# falls in the range used by BrianSystem.test_dataloader.
x, y = ScintillationDataset(10000, 2)[1]
print(x.shape, y.shape)

torch.Size([6, 4096]) torch.Size([4096])


In [16]:
class BrianSystem(pl.LightningModule):
    '''
    1-D convolutional regressor mapping a 6-channel sequence to a 1-channel
    sequence of the same length.

    The network alternates width-5 convolutions (padding 2) with width-1
    convolutions, so the sequence length is unchanged from input to output.
    Training, validation and test data all come from ``ScintillationDataset``
    with split indices 0, 1 and 2 respectively.

    NOTE(review): ``validation_end`` and ``@pl.data_loader`` are kept exactly
    as the notebook uses them; they appear to belong to an early
    pytorch-lightning API -- confirm against the installed version before
    upgrading.
    '''

    def __init__(self):
        super().__init__()
        # not the best model...
        self.conv1 = torch.nn.Conv1d(6, 64, 5, padding=2)
        self.conv2 = torch.nn.Conv1d(64, 8, 1)
        self.conv3 = torch.nn.Conv1d(8, 64, 5, padding=2)
        self.conv4 = torch.nn.Conv1d(64, 8, 1)
        self.conv5 = torch.nn.Conv1d(8, 64, 5, padding=2)
        self.conv6 = torch.nn.Conv1d(64, 1, 1)

    def forward(self, x):
        '''Run the conv stack; ReLU after every layer except the last.'''
        out = torch.relu(self.conv1(x))
        out = torch.relu(self.conv2(out))
        out = torch.relu(self.conv3(out))
        out = torch.relu(self.conv4(out))
        out = torch.relu(self.conv5(out))
        out = self.conv6(out)
        return out

    def _batch_mse(self, batch):
        '''Mean squared error between the prediction and the target of one batch.'''
        x, y = batch
        y_hat = self.forward(x)
        # forward() ends in a single-channel conv, so y_hat is
        # (batch, 1, length), while the dataset yields targets of shape
        # (batch, length). Without the explicit channel axis, mse_loss would
        # broadcast both to (batch, batch, length) and compare every prediction
        # with every target in the batch.
        if y.dim() == y_hat.dim() - 1:
            y = y.unsqueeze(1)
        return F.mse_loss(y_hat, y)

    def training_step(self, batch, batch_idx):
        # REQUIRED
        loss = self._batch_mse(batch)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        return {'val_loss': self._batch_mse(batch)}

    def validation_end(self, outputs):
        # OPTIONAL
        # Average the per-batch validation losses collected by validation_step.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):
        # REQUIRED
        # Split 0, batches of 32.
        return DataLoader(ScintillationDataset(10000, 0), 32)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        # Split 1, batches of 32.
        return DataLoader(ScintillationDataset(10000, 1), 32)

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        # Split 2, batches of 32.
        return DataLoader(ScintillationDataset(10000, 2), 32)

In [17]:
from pytorch_lightning import Trainer

# Build a new BrianSystem instance. Re-running this cell replaces `model`
# with a newly constructed network, so any training done on the previous
# instance is no longer reachable through this name.
model = BrianSystem()

In [13]:
# Trainer built with the library defaults; fit() is given only the model, which
# supplies its own dataloaders and optimizer through the hooks it defines.
trainer = Trainer()
trainer.fit(model)

In [18]:
# Inference only: no gradients are needed, so skip autograd bookkeeping.
# unsqueeze(0) adds the leading batch axis the Conv1d layers expect.
with torch.no_grad():
    y_hat = model(x.unsqueeze(0))

In [20]:
import matplotlib.pyplot as plt

# Overlay the dataset target and the network output for the sampled item.
fig, ax = plt.subplots(figsize=(10, 3), dpi=200)
ax.plot(y.numpy(), label='target (TEC)')
# y_hat is (batch, channel, length); take the only batch item and channel.
ax.plot(y_hat.detach().numpy()[0, 0, :], label='model output')
ax.set_xlabel('sample index')
ax.set_ylabel('TEC')
ax.set_title('Target vs. model output')
ax.legend()
plt.show()

<Figure size 2000x600 with 1 Axes>