# Import Modules

In [47]:
import numpy as np
import pandas as pd
from sklearn import preprocessing
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader

import utils

# Prepare Data

In [48]:
# Load the raw feature matrices for the source (1) and target (2) domains.
source_X = pd.read_csv("./deep_occupancy_detection/data/1_X_train.csv").values
target_X = pd.read_csv("./deep_occupancy_detection/data/2_X_train.csv").values

# Standardise both domains with statistics estimated on the source domain only.
scaler = preprocessing.StandardScaler().fit(source_X)
source_X, target_X = scaler.transform(source_X), scaler.transform(target_X)

# Stack the two domains row-wise and build the matching domain labels:
# 0 for every source row, 1 for every target row.
source_target_X = np.concatenate((source_X, target_X), axis=0)
source_target_y_domain = np.concatenate(
    (np.zeros(len(source_X)), np.ones(len(target_X))),
    axis=0,
)

In [49]:
# Convert features and domain labels to float32 tensors on the compute device.
# torch.as_tensor accepts NumPy arrays as well as existing tensors, so this cell
# can be re-run even though it rebinds the same names (the legacy torch.Tensor
# constructor is discouraged and is not guaranteed to accept a tensor that
# already lives on another device).
source_target_X = torch.as_tensor(
    source_target_X, dtype=torch.float32, device=utils.DEVICE
)
source_target_y_domain = torch.as_tensor(
    source_target_y_domain, dtype=torch.float32, device=utils.DEVICE
)

# Shuffled mini-batches of (features, domain label) pairs.
# NOTE(review): no random seed is set in the visible cells, so batch order (and
# therefore the trained weights) will differ between runs.
source_target_ds = TensorDataset(source_target_X, source_target_y_domain)
source_target_loader = DataLoader(source_target_ds, batch_size=16, shuffle=True)

# 1. 

# 2. Marginal Distribution Discrepancy between Source and Target

In [50]:
# Binary domain classifier: one output unit per sample, trained with Adam and
# binary cross-entropy on probabilities (sigmoid is applied by the caller).
n_features = source_target_X.shape[1]
domain_classifier = utils.Decoder(input_size=n_features, output_size=1)
domain_classifier = domain_classifier.to(utils.DEVICE)

criterion = nn.BCELoss()
optimizer = optim.Adam(domain_classifier.parameters(), lr=1e-3)

In [51]:
# Train the domain classifier to separate source rows (label 0) from target
# rows (label 1).
n_epochs = 100
for epoch in range(n_epochs):
    for X_batch, y_domain_batch in source_target_loader:
        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass: raw model output -> probability -> flat vector that
        # lines up with the 1-D label batch.
        pred_y = torch.sigmoid(domain_classifier(X_batch)).flatten()
        loss = criterion(pred_y, y_domain_batch)

        # Backward pass followed by the parameter update.
        loss.backward()
        optimizer.step()

In [52]:
# Score the trained domain classifier on the full source+target set.
# NOTE: these are the same rows the classifier was trained on, so the number
# below is a training accuracy, not a held-out estimate.
# NOTE(review): if utils.Decoder contains dropout or batch-norm layers, the
# model should also be switched to eval mode here - confirm against utils.
with torch.no_grad():  # inference only: do not record an autograd graph
    domain_prob = torch.sigmoid(domain_classifier(source_target_X)).reshape(-1)

# Threshold the probabilities at 0.5 to get the predicted domain (True = target).
pred_y = domain_prob > 0.5

# Vectorized accuracy; Python's built-in sum() would iterate element by element.
acc = (pred_y == source_target_y_domain).float().mean()
print(f"Domain Classification Accuracy: {acc}")

Domain Classification Accuracy: 0.7558059692382812


# 3. 