In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from dataset import ImageDataset, DomainDataset
from torch.utils.data import DataLoader,random_split
from dann import LeNetFeatureExtractor, LabelPredictor, DomainDiscriminator, DANN
from itertools import zip_longest
import torch.nn.functional as F


# Experiment configuration for this DANN run.
# NOTE(review): this cell only reads "mode", "batch", "epoch", "lr" and "cuda"
# directly; the remaining keys are passed on via `config=config` to the dataset
# and model constructors — presumably consumed there (confirm in dataset.py /
# dann.py). "alpha", "domain_weight" and "source_domain_label" are not used by
# the training loop in this cell.
config = {
    "mode": "train",  # sub-directory of base_dir used as the source-domain split
    "batch": 64,
    "epoch": 20,
    "lr": 1e-4,
    "cuda": 0,  # CUDA device index
    "norm": 2,
    "norm_type": "BN",
    "dropout": False,
    "weight_decay": False,
    "opt": "adam",
    "activation": "leaky_relu",
    "data_augmentation": False,
    "save": False,
    "dropout_rate": 0.7,
    "alpha": 0.1,  # You can set your desired initial alpha value here.
    "domain_weight": 0.7,
    "source_domain_label": 0,
}

base_dir = "dataset/"

# Select the configured CUDA device, falling back to the CPU (with a visible
# message) when CUDA is not available, instead of handing a non-existent GPU
# device to the datasets below.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{config['cuda']}")
else:
    print("CUDA is not available; falling back to CPU.")
    device = torch.device("cpu")

# Training hyperparameters, read from `config` so they stay in sync with it.
batch_size = config["batch"]
lr = config["lr"]
num_epochs = config["epoch"]
# NOTE(review): described as the domain-adversarial loss weight, but nothing in
# this cell applies it (the training loop adds the domain loss unweighted), and
# config["domain_weight"] holds a different value — confirm which is intended.
lambda_val = 0.1

# Source domain (domain=0): the `config["mode"]` split under base_dir.
trainset = DomainDataset(base_dir + config["mode"], device=device, config=config, domain=0, train=True)

# Target domain (domain=1): the "test" split. 10% of it is mixed into the
# training data and the remaining 90% is held out for evaluation.
test_dataset = DomainDataset(base_dir + 'test', device=device, config=config, domain=1, train=False)
num_target_train = int(0.1 * len(test_dataset))
target_domain_dataset, testdataset = random_split(
    test_dataset, [num_target_train, len(test_dataset) - num_target_train]
)

# Training batches mix source and target samples (shuffled together).
trainloader = DataLoader(
    torch.utils.data.ConcatDataset([trainset, target_domain_dataset]),
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    drop_last=True,
)
# Keep every held-out sample: with drop_last=True the final partial batch was
# silently excluded from the reported test accuracy, and a test split smaller
# than one batch produced zero evaluated samples.
testloader = DataLoader(
    testdataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=16,
    drop_last=False,
)

# Objectives: cross-entropy on the class logits and binary cross-entropy on the
# domain discriminator's output.
# NOTE(review): nn.BCELoss expects probabilities in [0, 1], so this presumably
# relies on DomainDiscriminator ending in a sigmoid — confirm in dann.py.
clf_loss_fn = nn.CrossEntropyLoss()
domain_loss_fn = nn.BCELoss()

# Assemble the DANN from its three sub-networks. The construction order is kept
# as-is because each module may draw its initial weights from the global RNG.
feature_extractor = LeNetFeatureExtractor(config=config)
label_predictor = LabelPredictor(config=config, num_classes=2)
domain_discriminator = DomainDiscriminator(config=config)
dann = DANN(feature_extractor, label_predictor, domain_discriminator)

# NOTE(review): the model is never moved to `device` in this cell; confirm the
# project modules (or the datasets) keep model and inputs on the same device.
optimizer = optim.Adam(params=dann.parameters(), lr=lr)

# Train the DANN model. Each batch mixes source (domain 0) and target (domain 1)
# samples; the label predictor and domain discriminator are optimised jointly
# through a single optimizer over all DANN parameters.
# NOTE(review): adversarial DANN training needs gradient reversal between the
# feature extractor and the discriminator. dann.domain_discriminator(features)
# is called directly here, so this only works if DomainDiscriminator reverses
# gradients internally — confirm in dann.py. Also note that neither lambda_val
# nor config["domain_weight"] is applied to domain_loss, and clf_loss is computed
# on the labelled target-domain samples as well as the source samples.
for epoch in range(num_epochs):
    dann.train()
    num_correct_train = 0
    num_total_train = 0
    for i, (inputs, labels, domains) in enumerate(trainloader):
        # BCE targets as a float column; assumes `domains` is a 1-D tensor of
        # 0/1 domain ids (0 for source, 1 for target) — TODO confirm.
        domain_labels = domains.unsqueeze(1).float()

        optimizer.zero_grad()

        # Shared features feed both heads.
        features = dann.feature_extractor(inputs)
        label_preds = dann.label_predictor(features)

        # Training accuracy: argmax over logits picks the same class as argmax
        # over softmax(logits), so the softmax is unnecessary.
        pred_labels = label_preds.argmax(dim=1)
        num_correct_train += (pred_labels == labels).sum().item()
        num_total_train += labels.size(0)

        clf_loss = clf_loss_fn(label_preds, labels)

        domain_preds = dann.domain_discriminator(features)
        domain_loss = domain_loss_fn(domain_preds, domain_labels)

        total_loss = clf_loss + domain_loss
        total_loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print("Epoch [{}/{}], Step [{}/{}], Clf Loss: {:.4f}, Domain Loss: {:.4f}, Total Loss: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(trainloader),
                          clf_loss.item(), domain_loss.item(), total_loss.item()))

    # trainloader uses drop_last=True, so it can yield no batches at all when the
    # training set is smaller than one batch; avoid dividing by zero then.
    if num_total_train:
        print("Epoch [{}/{}], Train Accuracy: {:.2f}%".format(
            epoch + 1, num_epochs, 100 * num_correct_train / num_total_train))
    else:
        print("Epoch [{}/{}], no training batches were produced.".format(epoch + 1, num_epochs))

# Evaluate the trained model on the held-out target-domain test split.
dann.eval()
with torch.no_grad():
    num_correct = 0
    num_total = 0
    for inputs, labels, domains in testloader:
        features = dann.feature_extractor(inputs)
        logits = dann.label_predictor(features)
        # argmax over logits equals argmax over softmax(logits).
        pred_labels = logits.argmax(dim=1)
        num_correct += (pred_labels == labels).sum().item()
        num_total += labels.size(0)

    if num_total:
        accuracy = 100 * num_correct / num_total
        # Evaluation runs once after training finishes, so report it against the
        # final epoch. Using num_epochs (rather than the leftover loop variable
        # `epoch`, undefined when num_epochs == 0) prints the same text when
        # training completed normally.
        print("Epoch [{}/{}], Test Accuracy: {:.2f}%".format(num_epochs, num_epochs, accuracy))
    else:
        print("No test samples to evaluate.")

# Set the model back to training mode
dann.train()


Epoch [1/20], Step [10/21], Clf Loss: 0.5790, Domain Loss: 0.1384, Total Loss: 0.7174
Epoch [1/20], Step [20/21], Clf Loss: 0.4835, Domain Loss: 0.3741, Total Loss: 0.8576
Epoch [1/20], Train Accuracy: 75.00%
Epoch [2/20], Step [10/21], Clf Loss: 0.3128, Domain Loss: 0.3856, Total Loss: 0.6983
Epoch [2/20], Step [20/21], Clf Loss: 0.1848, Domain Loss: 0.1834, Total Loss: 0.3681
Epoch [2/20], Train Accuracy: 90.77%
Epoch [3/20], Step [10/21], Clf Loss: 0.2300, Domain Loss: 0.3590, Total Loss: 0.5890
Epoch [3/20], Step [20/21], Clf Loss: 0.1489, Domain Loss: 0.2110, Total Loss: 0.3600
Epoch [3/20], Train Accuracy: 94.42%
Epoch [4/20], Step [10/21], Clf Loss: 0.1094, Domain Loss: 0.2352, Total Loss: 0.3447
Epoch [4/20], Step [20/21], Clf Loss: 0.1040, Domain Loss: 0.1289, Total Loss: 0.2330
Epoch [4/20], Train Accuracy: 96.13%
Epoch [5/20], Step [10/21], Clf Loss: 0.0756, Domain Loss: 0.2218, Total Loss: 0.2974
Epoch [5/20], Step [20/21], Clf Loss: 0.0749, Domain Loss: 0.1212, Total Loss: