In [None]:
# Top level imports
from typing import Optional
import pytorch_lightning as pl
import torch
import numpy as np
import numpy.typing as npt
from matplotlib import pyplot as plt

# Machine Learning in Fourier data.

Many properties of periodic systems are analyzed in reciprocal space using the Fourier transform of the data.
The goal of this notebook is to see if standard convolutional neural network (CNN) architectures can be used to learn frequency-dependent properties from periodic data.
Since the frequency-dependent properties are usually more decipherable (at least by humans) in reciprocal space, we will also see if examining the Fourier representation of the data in reciprocal space can help the learning process in any meaningful way.

As a contrived example we will assume each function is made up of a sum of randomly shifted and scaled sine waves periodic in the interval $[0,1]$ 

$$ f(x) = \sum_{n}  c_n  \sin(2 \pi n (x - \mu_{n})) $$

Using randomized periodic data in 1D, we will try to learn a fictitious *energy* function defined as the sum of the squared frequencies of the sine waves, weighted by their squared coefficients.

$$ E = \sum_{n}  c_n^2  n ^ 2 $$

This is reminiscent of the energy function of a particle in a box, where the energy is proportional to the square of the frequency of the sine wave.
$$ E_n = \frac{\hbar^2 \pi^2 n^2}{2 m L^2} $$

We will see how well a CNN can learn this energy function from the periodic grid data alone and how much adding the Fourier representation of the data helps the process.
Note that our situation is a bit more complex than the particle in a box, since we are allowing the sine waves to be shifted, so the function is no longer constrained to be zero at the domain boundaries.


## Randomly generated 1D data

Let's first write a function that creates the $f(x)$ functions above after providing a set of coefficients $c_n$ and the x-shifts $\mu_n$.
We can test the output of the function using the following parameters:
$$ c_0 = 1,\, c_1 = 2,\, c_2 = 3 $$
$$ \mu_0 = 0.2,\, \mu_1 = 0.4,\, \mu_2 = 0.6 $$
And plot the functions on top of each other.

In [None]:
def get_trig_function(coeff: list[float], mu: list[float]):
    """Return a function of x that sums shifted cosine harmonics.

    The returned callable evaluates

        c0 + c1 * cos(2 * pi * (x - μ_1)) + c2 * cos(4 * pi * (x - μ_2)) + ...

    i.e. ``sum_m coeff[m] * cos(2 * pi * m * (x - mu[m]))``, where the list
    index ``m`` is the harmonic number.

    Parameters
    ----------
    coeff :
        The coefficients c0, c1, c2, ...
    mu :
        The shifts μ_0, μ_1, μ_2, ...  Note that μ_0 does not affect the
        output, because the ``m = 0`` term reduces to the constant ``c0``.

    Returns
    -------
    callable
        ``func(x)`` taking a scalar or array-like ``x`` and returning a NumPy
        array of the same shape.

    Notes
    -----
    ``coeff`` and ``mu`` are paired with ``zip``, so extra entries in the
    longer of the two lists are ignored.
    """
    def func(x):
        # Accept lists/scalars as well as arrays, and promote integer or bool
        # grids to float: the in-place ``+=`` below cannot store float
        # results in an integer accumulator.
        x = np.asarray(x)
        if not np.issubdtype(x.dtype, np.inexact):
            x = x.astype(np.float64)
        res = np.zeros_like(x)
        for m, (c, x0n) in enumerate(zip(coeff, mu)):
            res += c * np.cos(2 * np.pi * m * (x - x0n))
        return res
    return func

# Sanity check: draw the generated function (solid) and the hand-written
# formula for the same parameters (dashed); the two curves should coincide.
x = np.linspace(0, 1, 100)
func = get_trig_function(coeff=[1, 2, 3], mu=[0.2, 0.4, 0.6])
plt.plot(x, func(x))
plt.plot(
    x,
    1 + 2 * np.cos(2 * np.pi * (x - 0.4)) + 3 * np.cos(4 * np.pi * (x - 0.6)),
    "--",
)


In [None]:
# generate a random 1 D function
def gen1D(
    xx: npt.ArrayLike,
    range_c: list[tuple[float, float]],
    range_mu: list[tuple[float, float]],
    targe_func: Optional[callable] = None,
    max_iter=1000):
    """Yield random 1D periodic functions together with a scalar target.

    Each iteration draws one coefficient per entry of ``range_c`` and one
    shift per entry of ``range_mu`` (uniformly within the given bounds),
    evaluates the resulting trigonometric function on ``xx`` and computes
    its shifted FFT.

    Args:
        xx: x values for the grid
        range_c: ``(low, high)`` bounds for each coefficient c_n
        range_mu: ``(low, high)`` bounds for each shift μ_n
        targe_func: a function that takes the list of
            `c` and `mu` parameters and returns a scalar.
            Defaults to the "energy" ``sum_n c_n**2 * n**2``.
        max_iter: number of samples the generator yields before stopping

    Yields:
        Data: float32 array stacking the function values, the FFT absolute
            value and the FFT phase (in that order along the first axis).
        target: float32 array of shape ``(1,)`` holding the target value.

    Note:
        The phase of Fourier coefficients whose magnitude is below 0.1 % of
        the largest magnitude is replaced by a uniform random angle in
        ``[-pi, pi)``.
    """
    if targe_func is None:
        # Default target E = sum_n c_n^2 * n^2, matching the energy defined in
        # the notebook text.  Resolved once, outside the sampling loop.
        def targe_func(c, mu):
            return np.sum([c_n**2 * n**2 for n, c_n in enumerate(c)])

    for _ in range(max_iter):
        # generate random parameters
        c = [np.random.uniform(*r) for r in range_c]
        mu = [np.random.uniform(*r) for r in range_mu]

        # get the target value ("energy" by default)
        energy = targe_func(c, mu)

        # get the gridded output and its centred Fourier transform
        yy = get_trig_function(c, mu)(xx)
        fft_yy = np.fft.fftshift(np.fft.fft(yy))
        fft_abs = np.abs(fft_yy)
        fft_arg = np.angle(fft_yy)

        # Randomize the phase wherever the magnitude is negligible
        # (presumably because that phase is only numerical noise).
        mask_small_abs = fft_abs < 1E-3 * np.max(fft_abs)
        fft_arg[mask_small_abs] = np.random.uniform(-np.pi, np.pi, size=np.sum(mask_small_abs))
        # Make sure the output is float32 to make it work with pytorch
        yield np.stack([yy, fft_abs, fft_arg]).astype(np.float32), np.array([energy]).astype(np.float32)

# DataSet 
class DS1D(torch.utils.data.IterableDataset):
    """Iterable dataset that streams random 1D samples produced by ``gen1D``.

    Parameters
    ----------
    x_max :
        Length of the periodic interval; the grid covers ``[0, x_max)``.
    gen1d_kwargs :
        Optional dict of keyword arguments forwarded to ``gen1D``.
    **kwargs :
        Additional keyword arguments forwarded to ``gen1D`` (for example
        ``range_c``, ``range_mu``, ``max_iter``).  They take precedence over
        entries of ``gen1d_kwargs`` with the same name.

    Attributes
    ----------
    xx :
        Real-space grid of 512 points.
    fft_xx :
        Centred (fft-shifted) frequencies matching ``xx``.
    """
    def __init__(self, x_max, gen1d_kwargs: dict = None, **kwargs):
        # The extra keywords are generator settings, not base-class options:
        # passing them to ``super().__init__`` raises a TypeError.
        super().__init__()
        self.x_max = x_max
        # Exclude the endpoint: for a function with period x_max the sample at
        # x_max would duplicate the one at 0 and smear the FFT peaks.
        self.xx = np.linspace(0, self.x_max, 512, endpoint=False)
        self.fft_xx = np.fft.fftshift(np.fft.fftfreq(self.xx.size, self.xx[1] - self.xx[0]))
        self.gen1d_kwargs = {**(gen1d_kwargs or {}), **kwargs}

    def __iter__(self):
        # Every call starts a fresh generator over the stored grid.
        yield from gen1D(xx=self.xx, **self.gen1d_kwargs)

# Three harmonics (n = 0, 1, 2): each coefficient is drawn from [-1, 1] and
# each shift from [0.2, 0.5].
range_c = [(-1, 1) for _ in range(3)]
range_mu = [(0.2, 0.5) for _ in range(3)]
train_ds = DS1D(x_max=1, range_c=range_c, range_mu=range_mu, max_iter=10000)
train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=128)
# Pull a single batch to inspect: `ex` are the inputs, `ey` the targets.
first_batch = iter(train_dl)
ex, ey = next(first_batch)


In [None]:
# Inspect sample 42 of the batch: real-space signal, FFT magnitude, FFT phase.
yy, fft_abs, fft_arg = ex[42, 0, :], ex[42, 1, :], ex[42, 2, :]
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 10))
ax1.plot(train_ds.xx, yy)
ax1.set_title("yy")
ax2.stem(train_ds.fft_xx, fft_abs)
ax3.plot(train_ds.fft_xx, fft_arg)
# Tag every panel near its upper-right corner (axes-fraction coordinates).
for axis, label in ((ax1, "real_space"), (ax2, "fft_abs"), (ax3, "fft_arg")):
    axis.text(0.9, 0.9, label, transform=axis.transAxes, ha="center", va="top", fontdict={"size": 20})
# Zoom both Fourier panels onto the low-frequency window.
ax2.set(xlim=(-10, 10)); ax3.set(xlim=(-10, 10))

In [None]:
# Model
class CNN1D(pl.LightningModule):
    """1D convolutional regressor predicting one scalar per sample.

    Two Conv1d -> ReLU -> MaxPool1d stages are followed by four fully
    connected layers with ReLU between them.  The input size of the first
    fully connected layer (3968) ties the network to inputs of length 512:
    512 -> 506 (conv, k=7) -> 253 (pool) -> 249 (conv, k=7, pad=1)
    -> 124 (pool), and 32 channels * 124 = 3968.

    Parameters
    ----------
    nchan :
        Number of leading input channels fed to the network; ``forward``
        uses ``x_in[:, 0:nchan, :]`` and ignores the remaining channels.
    """

    def __init__(self, nchan: int):
        super().__init__()
        self.nchan = nchan

        # First set of Conv->ReLU->MaxPool layers
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv1d(self.nchan, 32, kernel_size=7),
            torch.nn.ReLU(),
            torch.nn.MaxPool1d(kernel_size=2, stride=2),
        )

        # Second set of Conv->ReLU->MaxPool layers
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv1d(32, 32, kernel_size=7, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool1d(kernel_size=2, stride=2),
        )

        # Fully connected layers.  The activations between them are applied
        # in ``forward`` so that this container holds only the Linear layers
        # and their parameter names (out.0 ... out.3) stay unchanged.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(3968, 128),
            torch.nn.Linear(128, 64),
            torch.nn.Linear(64, 16),
            torch.nn.Linear(16, 1)
        )

    def forward(self, x_in: torch.Tensor):
        """Return predictions of shape ``(batch, 1)`` for ``x_in``."""
        # only use the first `nchan` channels
        x = x_in[:, 0:self.nchan, :]
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten everything except the batch dimension
        x = torch.flatten(x, 1)
        # Without a non-linearity the stacked Linear layers would collapse
        # into a single affine map, so apply ReLU after every hidden layer.
        *hidden_layers, head = self.out
        for layer in hidden_layers:
            x = torch.nn.functional.relu(layer(x))
        return head(x)

    def configure_optimizers(self):
        """Use Adam optimizer."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def _mse_step(self, batch, log_name: str):
        """Compute the MSE loss for ``batch`` and log it under ``log_name``."""
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss (logged as ``train_loss``)."""
        return self._mse_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss (logged as ``val_loss``)."""
        # Return the plain loss; the legacy dict with a nested 'log' entry
        # labelled the validation loss as 'train_loss'.
        return self._mse_step(batch, "val_loss")

    def test_step(self, batch, batch_idx):
        """Return the test loss (logged as ``test_loss``)."""
        return self._mse_step(batch, "test_loss")

# Build the model on all three input channels.
model = CNN1D(nchan=3)

# Training and validation streams share the same generator settings; each
# dataset draws its own random samples.
ds_kwargs = dict(x_max=1, range_c=range_c, range_mu=range_mu, max_iter=1000)
train_ds = DS1D(**ds_kwargs)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64)
test_ds = DS1D(**ds_kwargs)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=64, drop_last=True)

# Fit for 100 epochs, using the second loader for validation.
trainer = pl.Trainer(max_epochs=100)
trainer.fit(model, train_dl, test_dl);


In [None]:
# Draw one small batch (five samples) of fresh data for a quick look at the
# model's predictions.
test_ds = DS1D(x_max=1, max_iter=1000, range_c=range_c, range_mu=range_mu)
test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=5, drop_last=True)
test_batches = iter(test_dl)
ex, ey = next(test_batches)


In [None]:
# Show the model output for the sampled batch next to the true targets.
model(ex), ey

In [None]:
# Plot entries 200-399 of channel 1 for sample 1 of the batch
# (channel 1 is the FFT magnitude as stacked by gen1D).
plt.plot(ex[1,1,200:400])

In [None]:
# Scratch arithmetic; evaluates to 32.0.  (Stray leading characters removed:
# they turned the expression into an undefined name.)
16384 / 512

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset, get_worker_info
import math

class MyIterableDataset(IterableDataset):
    '''Iterable dataset yielding the integers ``start, ..., end - 1``.

    This dataset is copied from PyTorch docs.

    Parameters
    ----------
    start :
        First integer produced.
    end :
        One past the last integer produced; must be greater than ``start``.
    '''
    def __init__(self, start, end):
        # `super(MyIterableDataset).__init__()` builds an unbound super object
        # and never runs the base-class initialiser; use the zero-argument form.
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))

ds = MyIterableDataset(0, 10)
# NOTE(review): the loader wraps `train_ds`, not the `ds` created on the line
# above — confirm which dataset was meant.  A batch from `train_ds` is
# presumably an (inputs, targets) pair, which would explain the length of 2.
dl = DataLoader(train_ds, batch_size=32)
print(len(next(iter(dl)))) # should give 4, but gives 2

In [None]:
# Materialise every item the dataset yields.
list(ds)

In [None]:
# Build a 512 x 512 grid on the unit square ("ij" indexing: first axis is x).
xx = np.linspace(0, 1, 512)
yy = np.linspace(0, 1, 512)
XX, YY = np.meshgrid(xx, yy, indexing="ij")
# NOTE(review): `f1` is not defined in any cell above, so this cell fails on a
# fresh kernel — define the 2D function before using it.
plt.imshow(f1(XX, YY))


In [None]:
from pyrho.pgrid import PGrid
# Wrap the gridded 2D function in a PGrid with a unit square lattice.
# NOTE(review): relies on `f1`, `XX` and `YY` from the previous cell.
pgrid = PGrid(grid_data=f1(XX, YY), lattice=[[1, 0], [0, 1]])

In [None]:
# Transform the grid with a diagonal 2x2 matrix onto a 512 x 512 output grid
# (presumably a 2 x 2 supercell — confirm against the pyrho documentation).
pg2 = pgrid.get_transformed(sc_mat=[[2, 0], [0, 2]], origin=(0, 0), grid_out=(512, 512))

In [None]:
# get randomly oriented square lattice
def get_random_square_lattice(side_length, uc_lattice):
    """Return a randomly rotated square cell in unit-cell coordinates.

    A rotation angle is drawn uniformly from ``[0, 2*pi)``.  The matching 2D
    rotation matrix is scaled by ``side_length`` and multiplied from the
    right by the inverse of ``uc_lattice``.

    Parameters
    ----------
    side_length :
        Edge length of the square cell.
    uc_lattice :
        2 x 2 matrix of the unit-cell lattice.

    Returns
    -------
    numpy.ndarray
        2 x 2 matrix expressing the rotated square in terms of ``uc_lattice``.
    """
    angle = np.random.uniform(0, 2 * np.pi)
    cos_t, sin_t = np.cos(angle), np.sin(angle)
    rotated_square = np.array([[cos_t, -sin_t], [sin_t, cos_t]]) * side_length
    # express the square cell in the basis of the unit-cell lattice
    return rotated_square @ np.linalg.inv(uc_lattice)



In [None]:
# Resample the grid onto a randomly oriented square cell of side 1.5 and
# display the result.
res = get_random_square_lattice(side_length=1.5, uc_lattice=np.eye(2))
pg2 = pgrid.get_transformed(sc_mat=res, origin=(0, 0), grid_out=(512, 512))
plt.imshow(pg2.grid_data)

In [None]:
# for i in range(5):
#     res = get_random_square_lattice(1.5, np.eye(2))
#     pg2 = pgrid.get_transformed(sc_mat=res, origin=(0, 0), grid_out=(512, 512))
#     res = get_random_square_lattice(2.0, np.eye(2))
#     pg2 = pgrid.get_transformed(sc_mat=res, origin=(0, 0), grid_out=(512, 512))
#     fres = np.abs(np.fft.fft2(pg2.grid_data))
#     fres = np.fft.fftshift(fres)
#     width = 10
#     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 7))
#     ax1.imshow(fres[256-width:256+width, 256-width:256+width], vmin=0, vmax=3E4)
#     ax2.imshow(pg2.grid_data)

In [None]:
# Print the length of the first batch only, then stop iterating.
for batch in dl:
    print(len(batch))
    break

In [None]:
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 7))
# ax1.plot(xx, f)
# ax2.plot(xx_fourier, ff)
# ax2.set_xlim(-15, 15)

In [None]:
# CNN
class CNN1D(torch.nn.Module):
    """Plain ``torch.nn.Module`` 1D CNN with a three-value output.

    Architecture: two (Conv1d -> ReLU -> MaxPool1d) stages, dropout, and two
    fully connected layers.  Each pooling stage halves the length, so the
    flattened feature size is ``32 * (in_length // 4)``.

    NOTE: this class reuses the name ``CNN1D`` of the LightningModule defined
    earlier in the notebook and replaces it once this cell has run.

    Parameters
    ----------
    in_length :
        Length of the single-channel input signal.  The default of 1024
        reproduces the original hard-coded ``32 * 256`` feature size; pass
        ``in_length=512`` for the 512-point grids used by ``DS1D``.
    """
    # define all the layers used in model
    def __init__(self, in_length: int = 1024):
        super().__init__()
        self.conv1 = torch.nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = torch.nn.MaxPool1d(2, 2)
        # The convolutions preserve the length (k=3, pad=1); the two pools
        # each halve it (rounding down).
        self.flat_features = 32 * (in_length // 2 // 2)
        self.fc1 = torch.nn.Linear(self.flat_features, 128)
        self.fc2 = torch.nn.Linear(128, 3)
        self.dropout = torch.nn.Dropout(0.25)

    # define the forward pass
    def forward(self, x):
        """Map ``x`` of shape ``(batch, 1, in_length)`` to ``(batch, 3)``."""
        # one conv layer
        x = self.pool(torch.nn.functional.relu(self.conv1(x)))
        # second conv layer
        x = self.pool(torch.nn.functional.relu(self.conv2(x)))
        # Flatten while keeping the batch dimension fixed.  A hard-coded
        # ``view(-1, features)`` silently merges samples when the input
        # length is wrong; this raises a shape error in fc1 instead.
        x = torch.flatten(x, 1)
        # dropout layer
        x = self.dropout(x)
        # first dense layer
        x = torch.nn.functional.relu(self.fc1(x))
        # second dense layer
        return self.fc2(x)

    # define training step

Bash function that accepts a setting, e.g. `SETTING=1`, and replaces the setting in a given file.
If the setting is not found, it will be appended to the end of the file.
```bash
# set_setting KEY=VALUE FILE
#
# Replace the line starting with "KEY=" in FILE by KEY=VALUE.  If no such
# line exists, append KEY=VALUE to the end of FILE.
# KEY is expected to be a plain identifier (no regex metacharacters or '|').
# Uses in-place editing via `sed -i`.
function set_setting {
    local setting=$1
    local file=$2
    # Everything before the first '=' is the key to look for.
    local key=${setting%%=*}
    local escaped
    # Escape characters that are special in a sed replacement: '\', '&' and
    # the '|' delimiter used below.
    escaped=$(printf '%s' "$setting" | sed 's/[\\&|]/\\&/g')
    if grep -q -- "^${key}=" "$file"; then
        sed -i "s|^${key}=.*|${escaped}|" "$file"
    else
        printf '%s\n' "$setting" >> "$file"
    fi
}
```

script that accepts `-s` or `--setting` followed by the file name.
```bash
#!/bin/bash
# Usage: <script> SETTING VALUE FILE
#
# Set "SETTING = VALUE" in FILE: an existing line starting with SETTING
# followed by '=' (optionally surrounded by spaces) is rewritten, otherwise
# the line is appended to the end of FILE.
# SETTING is expected to be a plain identifier (no regex metacharacters or '|').
function set_setting {
    local setting=$1
    local value=$2
    local file=$3
    local escaped
    # Escape characters that are special in a sed replacement: '\', '&' and
    # the '|' delimiter used below.
    escaped=$(printf '%s' "$value" | sed 's/[\\&|]/\\&/g')
    if grep -q -- "^${setting}[[:space:]]*=" "$file"; then
        sed -i "s|^${setting}[[:space:]]*=.*|${setting} = ${escaped}|" "$file"
    else
        # Append the full assignment, not just the setting name.
        printf '%s = %s\n' "$setting" "$value" >> "$file"
    fi
}

if [ "$#" -ne 3 ]; then
    echo "usage: $0 SETTING VALUE FILE" >&2
    exit 1
fi

SETTING=$1
VALUE=$2
FILE=$3

# Pass all three arguments; omitting VALUE shifted FILE into its place.
set_setting "$SETTING" "$VALUE" "$FILE"
```

```bash



