# Exercise 34 (2) - Data-Driven Solver: identify reduced basis
### Task
Using the generated data, identify a reduced basis via a singular value decomposition (SVD). Modify the truncation level for the wave pressure with `uTruncation` and for the wave speed with `cTruncation`. How does the truncation level affect the reconstructions?

### Learning goals
- Understand how to simplify complex datasets via dimensionality reduction methods
- Familiarize yourself with the effect of truncating the singular value decomposition
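
As a compact summary of what the notebook computes, assume the samples (pressure traces or wave-speed profiles) are stored as the rows $\mathbf{x}$ of a snapshot matrix $X$, and let $V_r$ denote the first $r$ columns of $V$. The symbols $X$, $\mathbf{x}$, $\mathbf{a}$, $r$ and $V_r$ are introduced here for illustration only. In the code, `Vu` and `Vc` store $V_r^{\top}$, with $r$ given by `uTruncation` or `cTruncation`.

$$
X = U \Sigma V^{\top}, \qquad \Sigma = \operatorname{diag}(\Sigma_1, \Sigma_2, \dots), \qquad \Sigma_1 \ge \Sigma_2 \ge \dots \ge 0
$$

$$
\mathbf{a} = \mathbf{x} V_r \quad \text{(reduced coefficients)}, \qquad \tilde{\mathbf{x}} = \mathbf{a} V_r^{\top} = \mathbf{x} V_r V_r^{\top} \quad \text{(reconstruction)}
$$

$$
\frac{\lVert X - X V_r V_r^{\top} \rVert_F^2}{\lVert X \rVert_F^2} = \frac{\sum_{n > r} \Sigma_n^2}{\sum_{n} \Sigma_n^2}
$$

By the Eckart-Young theorem, $X V_r V_r^{\top}$ is the best rank-$r$ approximation of $X$ in the Frobenius norm. The relative reconstruction error over the data set is therefore set entirely by the discarded singular values: the faster they decay, the fewer modes are needed for an accurate reconstruction.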

In [None]:
import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader

In [None]:
import DataSet

In [None]:
torch.manual_seed(2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Pre-processing

**loading the measurement settings and splitting the data into training and validation sets**

In [None]:
settings = pd.read_csv("dataset1DFWI/settings.csv")
dataset = DataSet.FullWaveFormInversionDataset1D(settings, device)
datasetTraining, datasetValidation = torch.utils.data.random_split(dataset, [0.9, 0.1],
                                                                   generator=torch.Generator().manual_seed(2))
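
The split above passes fractions rather than absolute lengths to `random_split`. To our knowledge this is supported from PyTorch 1.13 on, and the fractions are converted into integer subset sizes that add up to the dataset size. The sketch below uses a hypothetical toy `TensorDataset` in place of the FWI data. It checks that a fixed generator seed reproduces the same partition and that the two subsets do not overlap.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# hypothetical toy dataset standing in for the FWI data: 25 samples with 4 features each
toy = TensorDataset(torch.arange(100.0).reshape(25, 4))

def split(seed):
    # fractional lengths are turned into integer subset sizes that sum to len(toy)
    return random_split(toy, [0.9, 0.1], generator=torch.Generator().manual_seed(seed))

trainA, validA = split(2)
trainB, validB = split(2)

print(len(trainA), len(validA))
# the same seed reproduces exactly the same partition of the indices
print(trainA.indices == trainB.indices)
# training and validation indices never overlap
print(set(trainA.indices).isdisjoint(validA.indices))
```

Seeding a dedicated `torch.Generator` keeps the split reproducible independently of the global seed set with `torch.manual_seed`.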

## Identification of the reduced basis

In [None]:
# load the whole training set as a single batch
dataloaderAll = DataLoader(datasetTraining, batch_size=len(datasetTraining))
uAll, cAll, _ = next(iter(dataloaderAll))

## Reduced-order identification of the encoding

**truncation level (wave pressure)**

In [None]:
uTruncation = 25

**truncated basis**

In [None]:
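# stack all traces into one snapshot matrix with one row per trace (time is the last dimension)
uReshaped = uAll.reshape(-1, uAll.shape[-1])
# reduced SVD: uSVD[1] holds the singular values in descending order,
# the rows of uSVD[2] are the corresponding right singular vectors
uSVD = torch.linalg.svd(uReshaped, full_matrices=False)
Vu = uSVD[2][:uTruncation]  # truncated basis: the first uTruncation right singular vectors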

torch.save(torch.as_tensor(Vu), "dataset1DFWI/measurementBasis.pt")
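
A self-contained sketch of the same steps on synthetic data, where the matrix `X` is an assumption standing in for `uReshaped`. It uses `torch.linalg.svd`, which returns the right singular vectors as the rows of `Vh`, so no transpose is needed (the older `torch.svd` returns `V` instead). The sketch checks that the truncated basis is orthonormal and that encoding followed by decoding recovers data that are nearly of rank `r`.

```python
import math
import torch

torch.manual_seed(0)

# synthetic stand-in for uReshaped: 200 samples of length 64, each a random
# combination of 5 sine modes plus small noise, so the data are nearly rank 5
t = torch.linspace(0.0, 1.0, 64)
modes = torch.stack([torch.sin((k + 1) * math.pi * t) for k in range(5)])  # (5, 64)
X = torch.randn(200, 5) @ modes + 1e-3 * torch.randn(200, 64)

r = 5
# reduced SVD: U is (200, 64), S is (64,), Vh is (64, 64); rows of Vh are right singular vectors
U, S, Vh = torch.linalg.svd(X, full_matrices=False)
V_r = Vh[:r]  # truncated basis with the same layout as Vu: (r, 64)

# the retained basis vectors are orthonormal
print(torch.allclose(V_r @ V_r.T, torch.eye(r), atol=1e-4))

# encode every sample into r coefficients and decode it again
coeffs = X @ V_r.T    # (200, r); equals U[:, :r] * S[:r]
X_rec = coeffs @ V_r  # (200, 64)
rel = (torch.linalg.norm(X - X_rec) / torch.linalg.norm(X)).item()
print(rel)  # small, because the data are nearly rank 5
```

Choosing `r` below the number of underlying modes in this toy example would make `rel` jump, which mirrors the question posed in the task.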

## Reduced-order identification of the decoding

**truncation level (wave speed)**

In [None]:
cTruncation = 50

**truncated basis**

In [None]:
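# reduced SVD of the wave-speed profiles (one row per sample)
cSVD = torch.linalg.svd(cAll, full_matrices=False)
Vc = cSVD[2][:cTruncation]  # truncated basis: the first cTruncation right singular vectors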

torch.save(torch.as_tensor(Vc), "dataset1DFWI/materialBasis.pt")
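
Following the section titles, the wave-pressure basis serves as an encoding and the wave-speed basis as a decoding. Below is a minimal sketch of how such a saved basis might be loaded and applied. It uses hypothetical helpers `encode` and `decode`, a random orthonormal stand-in basis, and a temporary file instead of `dataset1DFWI/materialBasis.pt`. Because the notebook may save tensors that live on a CUDA device, `map_location="cpu"` is passed to `torch.load`.

```python
import os
import tempfile
import torch

def encode(x, basis):
    """Hypothetical helper: coefficients of the samples (rows of x) in the basis (rows of basis)."""
    return x @ basis.T

def decode(a, basis):
    """Hypothetical helper: map reduced coefficients back to the full space."""
    return a @ basis

torch.manual_seed(0)
# stand-in basis with orthonormal rows (3 vectors in R^10), laid out like Vu and Vc
Q, _ = torch.linalg.qr(torch.randn(10, 3))
basis = Q.T

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "materialBasis.pt")
    torch.save(basis, path)
    # map_location="cpu" allows a basis saved from a GPU run to be loaded without a GPU
    loaded = torch.load(path, map_location="cpu")

x = torch.randn(4, 10)
a = encode(x, loaded)        # (4, 3) reduced coordinates
x_tilde = decode(a, loaded)  # (4, 10) orthogonal projection of x onto the span of the basis

# projecting a second time changes nothing, since decode(encode(.)) is a projection
print(torch.allclose(decode(encode(x_tilde, loaded), loaded), x_tilde, atol=1e-5))
```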

## Post-processing

**singular values of the wave pressure**

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))
# move the singular values to the CPU so that matplotlib can plot them when device is 'cuda'
sigmaU = uSVD[1].cpu()
ax.axvline(uTruncation, color='r', linewidth=2, label="truncation")
ax.plot(sigmaU, 'k.', label="singular values")
ax.set_yscale('log')
ax.set_xlabel("$n$")
ax.set_ylabel(r"Singular values $\Sigma_{n}$")
ax.grid()
ax.legend()
fig.tight_layout()
plt.show()
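
The red line marks where the basis is cut off. One common heuristic for placing it, which is not part of the original notebook, keeps the smallest number of modes whose squared singular values capture a prescribed fraction of the total "energy" $\sum_n \Sigma_n^2$. Here is a minimal sketch with a hypothetical helper `truncation_for_energy` and an arbitrarily chosen threshold:

```python
import torch

def truncation_for_energy(S, threshold=0.99):
    """Hypothetical helper: smallest number of modes whose squared singular
    values capture at least `threshold` of the total energy sum(S**2).
    S is expected in descending order, as returned by torch.linalg.svd."""
    energy = torch.cumsum(S**2, dim=0) / torch.sum(S**2)
    # count the leading prefixes that still fall short of the threshold;
    # one more mode reaches it; the result is capped at the number of modes
    return min(int((energy < threshold).sum()) + 1, S.numel())

# toy spectrum with squared values 100, 25, 1, 0.01
S = torch.tensor([10.0, 5.0, 1.0, 0.1])
print(truncation_for_energy(S, 0.99))  # prints 2
```

Applied to this notebook, `truncation_for_energy(uSVD[1])` or `truncation_for_energy(cSVD[1])` would return a candidate value for `uTruncation` or `cTruncation`. The threshold still has to be chosen with the reconstruction quality in mind.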

**reconstruction of wave pressure**

In [None]:
t = np.linspace(0, settings.N[0] * settings.dt[0], settings.N[0] + 1)  # N + 1 samples spaced dt apart
i = 200
fig, ax = plt.subplots(figsize=(6, 4))
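ax.plot(t, uReshaped[i].cpu(), 'k', linewidth=3, label="ground truth")
# project onto the truncated basis and map back; .cpu() lets matplotlib plot CUDA tensors
ax.plot(t, ((uReshaped[i] @ Vu.t()) @ Vu).cpu(), 'r:', linewidth=2, label="SVD reconstruction")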
ax.grid()
ax.set_xlabel("$t$")
ax.set_ylabel("$\\tilde{u}(t)$")
ax.legend()
fig.tight_layout()
plt.show()
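
To answer the question posed in the task quantitatively, the reconstruction error can be evaluated over the whole snapshot matrix rather than for a single trace. The sketch below uses a random stand-in matrix `X` (an assumption; substitute `uReshaped` or `cAll`). It compares the measured relative error for several truncation levels with the value predicted from the discarded singular values by the Eckart-Young theorem. A random matrix has a fairly flat spectrum, so the error decreases only slowly with `r` here.

```python
import torch

torch.manual_seed(0)
# random stand-in for a snapshot matrix (rows are samples); float64 for a tight comparison
X = torch.randn(50, 20, dtype=torch.float64)
U, S, Vh = torch.linalg.svd(X, full_matrices=False)

errors = {}
for r in (1, 5, 10, 20):
    V_r = Vh[:r]               # truncated basis, same layout as Vu and Vc
    X_rec = (X @ V_r.T) @ V_r  # encode, then decode every sample
    measured = torch.linalg.norm(X - X_rec) / torch.linalg.norm(X)
    # Eckart-Young: the residual energy equals the energy of the discarded singular values
    predicted = torch.sqrt(torch.sum(S[r:] ** 2) / torch.sum(S ** 2))
    errors[r] = (measured.item(), predicted.item())
    print(f"r = {r:2d}: measured {measured.item():.4f}, predicted {predicted.item():.4f}")
```

Because the predicted value only needs the singular values, the error for every possible truncation level can be read off the singular-value plot without reconstructing anything.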

**singular values of the wave speed**

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))
# move the singular values to the CPU so that matplotlib can plot them when device is 'cuda'
sigmaC = cSVD[1].cpu()
ax.axvline(cTruncation, color='r', linewidth=2, label="truncation")
ax.plot(sigmaC, 'k.', label="singular values")
ax.set_yscale('log')
ax.set_xlabel("$n$")
ax.set_ylabel(r"Singular values $\Sigma_{n}$")
ax.grid()
ax.legend(loc="best")
fig.tight_layout()
plt.show()

**reconstruction of wave speed**

In [None]:
x = np.linspace(0, settings.Lx[0], settings.Nx[0] + 1)
i = 20
fig, ax = plt.subplots(figsize=(6, 4))
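ax.plot(x, cAll[i, 1:-1].cpu(), 'k', linewidth=3, label="ground truth")
# reconstruct the full profile, then drop the first and last entries as for the ground truth
ax.plot(x, ((cAll[i] @ Vc.t()) @ Vc)[1:-1].cpu(), 'r:', linewidth=2, label="SVD\nreconstruction")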
ax.grid()
ax.set_xlabel("$x$")
ax.set_ylabel("$\\tilde{c}(x)$")
ax.legend(loc="best")
fig.tight_layout()
plt.show()