In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from pdb import set_trace as breakpoint
from time import time

import numpy as np
import matplotlib.pyplot as plt

from datasets.generators1d import * 
from datasets.settings import OMEGA, MU0, EPSILON0, SCALE, C, L0, PIXEL_SIZE
from models.utils import pbar, tensor_diff, tensor_roll

In [4]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (Restart & Run All) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
class MaxwellDense(nn.Module):
    """Dense network predicting a 1D field as ``A * cos(k(x) * x + phi)``.

    Two fully connected branches with identical layer widths map a permittivity
    profile of length ``size`` to a per-pixel amplitude ``A`` and phase ``phi``.
    ``forward`` returns the residual of the discretised Maxwell operator applied
    to the resulting field, optionally extended with an A/phi variation penalty.

    NOTE(review): the original cell was a half-finished refactor from a
    per-pixel ``epsilons`` input to a ``(layer_eps, layer_sizes)`` input. It
    referenced an undefined ``epsilons``, called ``forward_amplitude_phase``
    with two arguments, and read ``self.regularize_A_phi`` without ever setting
    it. This version keeps the new signatures and expands the layered
    description to a per-pixel profile before running the (unchanged) per-pixel
    maths. Confirm that this matches the intended design.

    Args:
        size: Number of pixels in the simulated domain (network input/output width).
        src_x: Pixel index of the current source.
        supervised: Stored on the instance; not used by any method in this class.
        drop_p: Dropout probability for the hidden layers of the amplitude branch.
        regularize_A_phi: If True, ``forward`` appends amplitude/phase variation
            penalties to its output.
        drop_p_phi: Dropout probability for the hidden layers of the phase branch.
    """

    # Value written into the free-current vector at the source pixel.
    # NOTE(review): magic number carried over unchanged; its derivation is not
    # visible in this notebook.
    SRC_CURRENT = 1.526814027933079

    def __init__(self, size=64, src_x=32, supervised=False, drop_p=0.1,
                 regularize_A_phi=False, drop_p_phi=0.05):
        super().__init__()

        self.size = size
        self.src_x = src_x
        self.supervised = supervised
        self.drop_p = drop_p
        self.drop_p_phi = drop_p_phi
        self.regularize_A_phi = regularize_A_phi

        self.layer_dims = [self.size, 128, 256, 256, 256, 128, self.size]

        # Layers are created in interleaved order (amp, phi, amp, phi, ...) so
        # that seeded weight initialisation matches the original construction.
        layers_amp = []
        layers_phi = []
        for d_in, d_out in zip(self.layer_dims[:-1], self.layer_dims[1:]):
            layers_amp.append(nn.Linear(d_in, d_out))
            layers_phi.append(nn.Linear(d_in, d_out))

        self.layers_amp = nn.ModuleList(layers_amp)
        self.layers_phi = nn.ModuleList(layers_phi)

    def _run_branch(self, layers, x, drop_p):
        """Apply ``layers`` with ReLU + dropout after every layer except the last."""
        last = len(layers) - 1
        for i, layer in enumerate(layers):
            x = layer(x)
            if i < last:
                # F.dropout honours self.training; the previous inline
                # nn.Dropout(...)(x) created a fresh module that was always in
                # training mode, so dropout stayed active after model.eval().
                x = F.dropout(F.relu(x), p=drop_p, training=self.training)
        return x

    def forward_amplitude_phase(self, x):
        """Return ``(A, phi)`` for a batch of per-pixel permittivity profiles.

        Args:
            x: Tensor whose last dimension has length ``self.size``.

        Returns:
            A: ``ELU(.) + 1 + 0.1`` of the amplitude branch output, so A >= 0.1.
            phi: ``2 * pi * tanh(.)`` of the phase branch output, in [-2*pi, 2*pi].
        """
        A = F.elu(self._run_branch(self.layers_amp, x, self.drop_p)) + 1 + 0.1
        phi = 2 * np.pi * torch.tanh(self._run_branch(self.layers_phi, x, self.drop_p_phi))
        return A, phi

    def _layers_to_pixels(self, layer_eps, layer_sizes=None):
        """Expand a layered permittivity description to a ``(batch, size)`` tensor.

        NOTE(review): the layered input format is an assumption; the original
        code never finished this step. Confirm against the data generator.

        Args:
            layer_eps: ``(batch, n_layers)`` (or 1D) permittivity per layer. If
                ``layer_sizes`` is None it is taken to be per-pixel already.
            layer_sizes: Pixel count of each layer, either 1D (shared by the
                whole batch) or ``(batch, n_layers)``. Must sum to ``self.size``.

        Raises:
            ValueError: If the expanded profile is not ``self.size`` pixels wide.
        """
        if not torch.is_tensor(layer_eps):
            layer_eps = torch.tensor(layer_eps, dtype=torch.float)
        if layer_eps.dim() == 1:
            layer_eps = layer_eps.unsqueeze(0)

        if layer_sizes is None:
            epsilons = layer_eps
        else:
            sizes = torch.as_tensor(layer_sizes, dtype=torch.long, device=layer_eps.device)
            if sizes.dim() == 1:
                epsilons = torch.repeat_interleave(layer_eps, sizes, dim=-1)
            else:
                rows = [torch.repeat_interleave(e, s) for e, s in zip(layer_eps, sizes)]
                if any(row.shape[-1] != self.size for row in rows):
                    raise ValueError(
                        "each row of layer_sizes must sum to size=%d" % self.size)
                epsilons = torch.stack(rows)

        if epsilons.shape[-1] != self.size:
            raise ValueError(
                "permittivity profile has %d pixels, expected size=%d"
                % (epsilons.shape[-1], self.size))
        return epsilons

    def _compose_fields(self, epsilons, A, phi, add_zero_bc=False):
        """Combine amplitude and phase into ``A * cos(OMEGA / C * sqrt(eps) * x + phi)``."""
        # Pixel positions measured from the source, scaled by PIXEL_SIZE.
        x = PIXEL_SIZE * (
            torch.arange(self.size, dtype=torch.float, device=epsilons.device) - self.src_x)
        fields = A * torch.cos(OMEGA / C * torch.sqrt(epsilons) * x + phi)

        if add_zero_bc:
            zero = torch.zeros((epsilons.shape[0], 1), dtype=fields.dtype, device=fields.device)
            fields = torch.cat([zero, fields, zero], dim=-1)
        return fields

    def get_fields(self, layer_eps, layer_sizes=None, add_zero_bc=False):
        """Predict the field for the given permittivity description.

        Args:
            layer_eps: See ``_layers_to_pixels``.
            layer_sizes: See ``_layers_to_pixels``; None means ``layer_eps`` is
                already a per-pixel profile.
            add_zero_bc: If True, pad one zero on each side of the last dimension.

        Returns:
            Tensor of shape ``(batch, size)``, or ``(batch, size + 2)`` when
            ``add_zero_bc`` is True.
        """
        epsilons = self._layers_to_pixels(layer_eps, layer_sizes)
        A, phi = self.forward_amplitude_phase(epsilons)
        return self._compose_fields(epsilons, A, phi, add_zero_bc=add_zero_bc)

    def forward(self, layer_eps, layer_sizes=None):
        """Return the Maxwell-operator residual ``curl_curl_E - epsilon_E - J``.

        The residual is evaluated on the field padded with zero boundary
        values; the two boundary entries are removed from the result. When
        ``self.regularize_A_phi`` is set, two extra columns (scaled total
        variation of A and of phi) are appended after the interior residual.
        """
        epsilons = self._layers_to_pixels(layer_eps, layer_sizes)
        batch_size = epsilons.shape[0]

        # Evaluate the network once so the fields and the regulariser see the
        # same A/phi (and the same dropout mask when training).
        A, phi = self.forward_amplitude_phase(epsilons)

        # Add zero field amplitudes at edge points for resonator BC's
        E = self._compose_fields(epsilons, A, phi, add_zero_bc=True)
        zero = torch.zeros((batch_size, 1), dtype=E.dtype, device=E.device)

        # Add first layer of cavity BC's
        barrier = torch.full((batch_size, 1), -1e10, dtype=epsilons.dtype, device=epsilons.device)
        eps = torch.cat([barrier, epsilons, barrier], dim=-1)

        # Compute Maxwell operator on fields
        # NOTE(review): tensor_diff is a project helper; with n=2, padding=None
        # it presumably returns a second difference two entries shorter than E,
        # hence the re-padding with zeros. Confirm against models.utils.
        diffs = tensor_diff(E, n=2, padding=None)
        curl_curl_E = (SCALE / PIXEL_SIZE**2) * torch.cat([zero, diffs, zero], dim=-1)
        epsilon_E = (SCALE * -OMEGA**2 * MU0 * EPSILON0) * eps * E

        # Compute free-current vector (index shifted by one for the left padding)
        J = torch.zeros_like(E)
        J[:, self.src_x + 1] = self.SRC_CURRENT

        out = curl_curl_E - epsilon_E - J

        # Drop the two boundary entries *before* appending the penalties; the
        # original sliced afterwards, which would have discarded the phi
        # penalty instead of the right-hand boundary residual.
        out = out[:, 1:-1]

        # Penalize excessive variation in A/phi
        if self.regularize_A_phi:
            A_variation = torch.sum(torch.abs(tensor_diff(A)), -1, keepdim=True)
            phi_variation = torch.sum(torch.abs(tensor_diff(phi)), -1, keepdim=True)
            eps_variation = torch.sum(torch.abs(tensor_diff(epsilons)), -1, keepdim=True)
            # Penalty is weakened where the permittivity itself varies a lot.
            factor = 1e-1 / (1 + eps_variation)
            out = torch.cat([out, factor * A_variation, factor * phi_variation], dim=-1)

        return out

In [5]:
# Scratch tensor holding the integers 1..3, used to try out iteration and indexing.
a = torch.arange(1, 4)

In [7]:
# Iterating a 1D tensor yields 0-dim tensors, one per element.
for element in a:
    print(element)

tensor(1)
tensor(2)
tensor(3)


In [8]:
# Integer indexing returns a 0-dim tensor (see the output below), not a Python int.
a[0]

tensor(1)