# How to use NFM for your own data and task

This notebook does not explain how to use the pre-built packages (NFM + predictor) in each task sub-folder. Instead, it shows how to use the NFM backbone from the model folder and how to train NFM on your own data.


## Basic setup  

```python
import numpy as np
import torch
from torch import nn
# HyperVariables and NFM_general are defined in the repository's model folder;
# import them from there (the exact module path depends on your checkout).

# Seed
np.random.seed(88)
torch.manual_seed(88)

# Generate a "mini-batch" of input data (replace this with your own dataset and dataloader)
train_x_mb = torch.randn((32, 720, 10))  # mini-batch format: (batch, input length, channels)

# Generate mini-batch target data (we show two cases)

## True labels for classification (shown for illustration only; not used in this demo)
train_y_label = torch.randint(0, 10, (32,))
target_lenth_label = 0

## Ground-truth values for forecasting (horizon = 180)
train_y_horizon = torch.randn((32, 180, 10))
target_lenth_horizon = 180

print("x in: ", train_x_mb.shape)
print("target horizon: ", train_y_horizon.shape)
```
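Later cells call `hypervars.input_` and `hypervars.output_` to rearrange the mini-batch according to channel (in)dependence. The sketch below illustrates one plausible channel-independent layout, assuming channels are folded into the batch dimension, (B, N, C) to (B·C, N, 1), which is consistent with the 1-dimensional prediction head used below. The helpers `fold_channels` and `unfold_channels` are hypothetical and are not NFM's actual implementation.

```python
import torch

def fold_channels(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical: (B, N, C) -> (B*C, N, 1), treating each channel as its own series."""
    b, n, c = x.shape
    return x.permute(0, 2, 1).reshape(b * c, n, 1)

def unfold_channels(y: torch.Tensor, c: int) -> torch.Tensor:
    """Hypothetical inverse of fold_channels: (B*C, L, 1) -> (B, L, C)."""
    bc, length, _ = y.shape
    return y.reshape(bc // c, c, length).permute(0, 2, 1)

x = torch.randn(32, 720, 10)
x_ci = fold_channels(x)
print(x_ci.shape)                                  # torch.Size([320, 720, 1])
print(torch.equal(unfold_channels(x_ci, 10), x))   # True
```

Because each row of `x_ci` is a single univariate series, per-series statistics (such as the instance normalization used in training) are computed independently for every channel.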

## Set up arguments and instantiate NFM

```python
# Model arguments (forecasting setup)

N = 720  # input length
L = target_lenth_horizon  # use target_lenth_label instead for classification

Fx = N  # sampling rate of the input time series (we always assume Tx = 1)
Fy = Fx  # sampling rate of the output latent variable
         # (Fy is not always equal to Fx; see anomaly detection in our main work)

training_T_F = [Fy, Fx, L, N]  # [output SR, input SR, target length, input length]
                               # for classification, only change L to 0 (target_lenth_label)
testing_T_F = [Fy, Fx, L, N]   # setup used at testing time

hypervars = HyperVariables(sets_in_training = training_T_F,
                            sets_in_testing = testing_T_F,
                            C_ = 10,
                            freq_span = -1,  # full-spectrum prediction
                            channel_dependence = False,  # False makes NFM channel-independent

                            # Mixing block
                            filter_type = "INFF",
                            hidden_dim = 32,
                            inff_siren_hidden = 32,
                            inff_siren_omega = 30,
                            layer_num = 1,  # number of mixing blocks

                            lft = True,
                            lft_siren_dim_in = 32,
                            lft_siren_hidden = 32,
                            lft_siren_omega = 30,

                            print_inf = False  # turn off info printing
                            )

# Construct the NFM backbone
NFM_backbone = NFM_general(hypervars)

# Prediction head (output dim is 1 because channels are processed independently)
predictor = nn.Linear(32, 1)
```
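The comments above fix Tx = 1, so Fx = N simply means that the N input samples span one time unit. The NumPy sketch below shows the time stamps this convention implies, assuming (as the full-span training loss suggests) that the output covers the input span plus the horizon, i.e. N + L samples at rate Fy. It illustrates the stated convention only and is not code from NFM.

```python
import numpy as np

N, L = 720, 180        # input length and forecast horizon from the setup above
Fx = N                 # Tx = 1, so the input sampling rate equals the number of input samples
Fy = Fx                # output sampling rate (equal to Fx in this forecasting setup)

t_in = np.arange(N) / Fx          # input time stamps in [0, 1)
t_out = np.arange(N + L) / Fy     # output: input span followed by the horizon

print(t_in[-1] < 1.0, t_out.shape)  # True (900,)
print(L / Fy)                       # 0.25 -> horizon length in units of Tx
```

Halving the number of input samples while keeping Tx = 1 halves Fx; this is exactly the change made in the testing section.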

Move `NFM_backbone` and `predictor` to a GPU if one is available; `hypervars` does not need to be moved.
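A minimal sketch of the device move using standard PyTorch calls. `backbone_stub` and `head_stub` are hypothetical placeholders for `NFM_backbone` and `predictor` so that the snippet runs without the NFM code; with the real modules, call `.to(device)` on them in the same way and move each mini-batch to the same device.

```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical placeholders (NOT NFM_general): any nn.Module moves the same way.
backbone_stub = nn.Linear(1, 32).to(device)
head_stub = nn.Linear(32, 1).to(device)

# Inputs must live on the same device as the modules' parameters.
x = torch.randn(320, 720, 1, device=device)
y = head_stub(backbone_stub(x))
print(y.shape, y.device.type)
```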

## Training

```python
# Optimizer over both the backbone and the prediction head (add schedulers, decay, etc. as needed)
opt = torch.optim.Adam(list(NFM_backbone.parameters()) + list(predictor.parameters()), lr = 0.0001)
# Criterion
criterion = nn.MSELoss()

# Training

hypervars.training_set()  # make NFM act on the "training_T_F" setup
NFM_backbone.train()
predictor.train()
for epoch in range(1):
    opt.zero_grad()
    train_x_in = hypervars.input_(train_x_mb)  # rearranges the mini-batch according to channel (in)dependence

    # (1) Instance normalization for forecasting
    train_x_in_mean = torch.mean(train_x_in, dim=1, keepdim=True)
    train_x_in_std = torch.sqrt(torch.var(train_x_in, dim=1, keepdim=True) + 1e-5)
    train_x_in = train_x_in - train_x_in_mean
    train_x_in = train_x_in / train_x_in_std

    # (2) NFM forward pass
    z, _, _, _, _ = NFM_backbone(train_x_in)  # output length is N + L (input span + horizon)!
    y = predictor(z)

    # (3) Reverse instance normalization
    y = y * train_x_in_std + train_x_in_mean
    y, y_freq = hypervars.output_(y)  # rearranges the output mini-batch back according to channel (in)dependence

    # (4) Loss over the full span (N + L): input reconstruction plus horizon
    fullspan_loss = criterion(y, torch.cat((train_x_mb, train_y_horizon), dim = 1))
    fullspan_loss.backward()
    opt.step()


print("** From Training ** ")
print("Input x: ", train_x_mb.shape)
print("Output z: ", z.shape)
print("Output y: ", y.shape)
print("Output y_freq: ", y_freq.shape)
print("Output horizon y: ", y[:, N:, :].shape)


# The classification setup is the same, except that steps (1) and (3) are not needed.
```
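Steps (1) and (3) are repeated verbatim in the testing cell, so they can be factored into two small helpers. This is a refactoring sketch of the exact arithmetic used above (unbiased `torch.var` over the time axis plus `1e-5`); the function names are hypothetical. Because the statistics have shape (B·C, 1, 1), they broadcast over the longer N + L output.

```python
import torch

def instance_norm(x: torch.Tensor, eps: float = 1e-5):
    """Step (1): standardize each series over the time axis (dim=1)."""
    mean = torch.mean(x, dim=1, keepdim=True)
    std = torch.sqrt(torch.var(x, dim=1, keepdim=True) + eps)
    return (x - mean) / std, mean, std

def reverse_instance_norm(y: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Step (3): map outputs back to each series' original scale."""
    return y * std + mean

x = torch.randn(320, 720, 1) * 5 + 3        # e.g. (B*C, N, 1) after channel folding
x_norm, mean, std = instance_norm(x)
print(mean.shape)                            # torch.Size([320, 1, 1])

# A round trip on the input recovers the original values
x_back = reverse_instance_norm(x_norm, mean, std)
print(torch.allclose(x_back, x, atol=1e-4))  # True

# The per-series statistics broadcast over an N + L output
y_back = reverse_instance_norm(torch.randn(320, 900, 1), mean, std)
print(y_back.shape)                          # torch.Size([320, 900, 1])
```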

## Testing at different sampling rate ($m_f \neq 1$)
Testing the trained NFM on input time series sampled at a different rate is straightforward.

Set a new input sampling rate and input length in the testing setup, and let NFM work with this new set of arguments; no retraining is needed.
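For example, a test input at half the training rate can be produced by keeping every other sample of a 720-step window; only Fx and N change in the testing list. This sketch uses synthetic data, and `full_rate_x` is a hypothetical placeholder for your own series. It uses plain decimation without an anti-aliasing filter, which is enough for illustration.

```python
import torch

N, L = 720, 180
Fy = N                                  # output rate kept from training (Fy = Fx = N)

full_rate_x = torch.randn(32, N, 10)    # hypothetical full-rate window (B, N, C)
test_x = full_rate_x[:, ::2, :]         # keep every other sample -> half the rate

N_testing_time = test_x.shape[1]        # 360 samples over the same unit span
Fx_testing_time = N_testing_time        # Tx = 1, so Fx equals N
testing_T_F = [Fy, Fx_testing_time, L, N_testing_time]
print(test_x.shape, testing_T_F)        # torch.Size([32, 360, 10]) [720, 360, 180, 360]
```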

```python
# Testing-time inputs sampled at half the training input sampling rate
test_x = torch.randn((32, 360, 10))  # N = 360 over the same time span (downsampled), so Fx = 360 while Tx is unchanged
test_y_horizon = torch.randn((32, 180, 10))  # the prediction target is unchanged

N_testing_time = 360
Fx_testing_time = N_testing_time

# Update testing_T_F
testing_T_F = [Fy, Fx_testing_time, L, N_testing_time]
hypervars.sets_in_testing = testing_T_F

# Apply the update to NFM
hypervars.testing_set()

# Inference (evaluation mode, no gradient tracking)
NFM_backbone.eval()
predictor.eval()
with torch.no_grad():
    test_x_in = hypervars.input_(test_x)

    ## (1) Instance normalization for forecasting
    test_x_in_mean = torch.mean(test_x_in, dim=1, keepdim=True)
    test_x_in_std = torch.sqrt(torch.var(test_x_in, dim=1, keepdim=True) + 1e-5)
    test_x_in = test_x_in - test_x_in_mean
    test_x_in = test_x_in / test_x_in_std

    ## (2) NFM forward pass
    test_z, _, _, _, _ = NFM_backbone(test_x_in)
    test_y = predictor(test_z)

    ## (3) Reverse instance normalization
    test_y = test_y * test_x_in_std + test_x_in_mean
    y, y_freq = hypervars.output_(test_y)  # rearranges the output mini-batch back according to channel (in)dependence

# The output is still sampled at Fy (= N = 720), so the horizon starts at index N, as in training
y_horizon = y[:, N:, :]


print("** From Testing at different SR** ")
print("Input testing time x: ", test_x.shape)
print("Output y from downsampled x: ", y.shape)
print("Output y_freq from downsampled x: ", y_freq.shape)
print("Output horizon y from downsampled x: ", y_horizon.shape)

# The same procedure applies to classification, except that steps (1) and (3) are not needed.
```
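The testing cell creates `test_y_horizon` but never scores the forecast against it. A minimal way to do so is to slice off the input span and compute MSE and MAE on the horizon only. The sketch below uses random stand-ins (`y_out`, `y_true`, both hypothetical) with the shapes printed above, (B, N + L, C) for the output and (B, L, C) for the target.

```python
import torch
from torch import nn

N, L = 720, 180
y_out = torch.randn(32, N + L, 10)   # stand-in for the rearranged output y
y_true = torch.randn(32, L, 10)      # stand-in for test_y_horizon

y_hor = y_out[:, N:, :]              # keep only the forecast part
mse = nn.MSELoss()(y_hor, y_true)
mae = torch.mean(torch.abs(y_hor - y_true))
print(y_hor.shape, round(mse.item(), 3), round(mae.item(), 3))
```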

The example above covers the case $m_f > 1$ at testing time; the case $m_f < 1$ follows the same procedure.
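Assuming $m_f < 1$ corresponds to the opposite direction of the example above (a test input sampled *above* the training rate), only Fx and N in the testing list change again. In this sketch, linear interpolation with `torch.nn.functional.interpolate` fabricates a 1440-sample input purely for illustration; in practice you would feed genuinely higher-rate data.

```python
import torch
import torch.nn.functional as F

N, L = 720, 180
Fy = N                                    # output rate kept from training

x = torch.randn(32, N, 10)                # (B, N, C) at the training rate
# F.interpolate expects (B, C, length), so move channels in front of time
x_up = F.interpolate(x.permute(0, 2, 1), scale_factor=2, mode="linear", align_corners=False)
test_x_hi = x_up.permute(0, 2, 1)         # (B, 2N, C): twice as many samples over the same span

N_testing_time = test_x_hi.shape[1]       # 1440
Fx_testing_time = N_testing_time          # Tx = 1, so Fx equals N
testing_T_F = [Fy, Fx_testing_time, L, N_testing_time]
print(test_x_hi.shape, testing_T_F)       # torch.Size([32, 1440, 10]) [720, 1440, 180, 1440]
```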