Virtual Adversarial Training (VAT) implementation for PyTorch

lyakaap/VAT-pytorch

VAT-pytorch

Virtual Adversarial Training (VAT) implementation for PyTorch

Usage

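import torch.nn as nn
# VATLoss is the loss class provided by this repository; import it from your checkout.

# The loss modules can be created once, outside the training loop.
vat_loss = VATLoss(xi=10.0, eps=1.0, ip=1)
cross_entropy = nn.CrossEntropyLoss()

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()

    # LDS (local distributional smoothness) should be calculated
    # before the forward pass for cross entropy.
    lds = vat_loss(model, data)
    output = model(data)
    # args.alpha weights the VAT regularizer against the supervised loss.
    loss = cross_entropy(output, target) + args.alpha * lds
    loss.backward()
    optimizer.step()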
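
To give a sense of what the `vat_loss(model, data)` call above computes, here is a minimal sketch of a VAT-style LDS term. It is not this repository's code. It assumes the model returns class logits, and it assumes `xi`, `eps`, and `ip` mean the finite-difference scale, the perturbation radius, and the number of power iterations, as in the VAT paper. The names `vat_loss_sketch` and `l2_normalize` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def l2_normalize(d):
    # Scale each sample's perturbation to unit L2 norm.
    flat = d.flatten(start_dim=1)
    norm = flat.norm(dim=1).view(-1, *([1] * (d.dim() - 1)))
    return d / (norm + 1e-8)


def vat_loss_sketch(model, x, xi=10.0, eps=1.0, ip=1):
    # Hypothetical helper: a sketch of the LDS term, not the repo's VATLoss.
    # Prediction on clean inputs, treated as a fixed target (no gradient).
    with torch.no_grad():
        pred = F.softmax(model(x), dim=1)

    # Start from a random unit direction per sample.
    d = l2_normalize(torch.randn_like(x))

    # Power iteration: move d along the gradient of the KL divergence.
    for _ in range(ip):
        d.requires_grad_(True)
        logp_hat = F.log_softmax(model(x + xi * d), dim=1)
        dist = F.kl_div(logp_hat, pred, reduction="batchmean")
        grad = torch.autograd.grad(dist, d)[0]
        d = l2_normalize(grad.detach())

    # LDS: divergence between clean and adversarially perturbed predictions.
    logp_hat = F.log_softmax(model(x + eps * d), dim=1)
    return F.kl_div(logp_hat, pred, reduction="batchmean")
```

The returned scalar can stand in for `lds` in the training loop above. The sketch needs no labels, so it can also be applied to unlabeled batches. It does not treat batch-normalization statistics specially during the extra forward passes; a model with such layers may need additional care.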
