
jaber628/fret


FRET: Feature Redundancy Elimination for Test Time Adaptation

This repository is the official PyTorch implementation of 'Feature Redundancy Elimination for Test Time Adaptation'.
This codebase is mainly based on TSD and AETTA.

Dependencies

We use python==3.8.13 with the following packages:

torch==1.12.0+cu113
torchvision==0.13.0+cu113
numpy==1.24.4
pandas==2.0.3
tqdm==4.66.2
timm==0.9.16
scikit-learn==1.3.2 
pillow==10.3.0

We also share our full Python environment, which contains all required packages, in the ./FRET.yml file.
You can create this environment with conda:

conda env create -f FRET.yml -n FRET
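Before running anything, it can help to confirm that the active environment matches the pinned versions listed above. The following is a minimal sketch using only the standard library's importlib.metadata; the helper name find_mismatches is hypothetical (it is not part of this repository), and it assumes the package names above match the names pip registered, so a package reported as missing may simply be registered under a slightly different name.

```python
from importlib.metadata import PackageNotFoundError, version

# Pinned versions, copied from the dependency list in this README.
PINNED = {
    "torch": "1.12.0+cu113",
    "torchvision": "0.13.0+cu113",
    "numpy": "1.24.4",
    "pandas": "2.0.3",
    "tqdm": "4.66.2",
    "timm": "0.9.16",
    "scikit-learn": "1.3.2",
    "pillow": "10.3.0",
}


def find_mismatches(pinned, get_version=version):
    """Return {package: (expected, installed)} for every package whose installed
    version differs from the pin; installed is None if the package is not found."""
    mismatches = {}
    for pkg, expected in pinned.items():
        try:
            installed = get_version(pkg)
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[pkg] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for pkg, (want, have) in find_mismatches(PINNED).items():
        print(f"{pkg}: expected {want}, found {have}")
```

An empty result means every pinned package was found at exactly the listed version.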

Dataset

Download the PACS and OfficeHome datasets used in our paper:
PACS
OfficeHome
Then organize them as follows:

|-your_data_dir
  |-PACS
    |-art_painting
    |-cartoon
    |-photo
    |-sketch
  |-OfficeHome
    |-Art
    |-Clipart
    |-Product
    |-RealWorld
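A quick way to catch layout mistakes before training is to check that each domain folder exists. This is a hypothetical helper (missing_dirs is not part of the repository) that only checks the folder names from the tree above; it does not verify the images inside them.

```python
from pathlib import Path

# Expected dataset/domain folders, copied from the directory tree above.
EXPECTED_DOMAINS = {
    "PACS": ["art_painting", "cartoon", "photo", "sketch"],
    "OfficeHome": ["Art", "Clipart", "Product", "RealWorld"],
}


def missing_dirs(data_dir, expected=EXPECTED_DOMAINS):
    """Return the 'dataset/domain' folders that do not exist under data_dir."""
    root = Path(data_dir)
    return [
        f"{dataset}/{domain}"
        for dataset, domains in expected.items()
        for domain in domains
        if not (root / dataset / domain).is_dir()
    ]


if __name__ == "__main__":
    # Replace with your actual data directory; an empty list means the layout matches.
    print(missing_dirs("your_data_dir"))
```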

To download the CIFAR10/CIFAR10-C and CIFAR100/CIFAR100-C datasets, run the following commands:

$ . download_cifar10c.sh        # download CIFAR10/CIFAR10-C datasets
$ . download_cifar100c.sh       # download CIFAR100/CIFAR100-C datasets

You can also download the VLCS, DomainNet, and ImageNet-C datasets from their original sources.

Train source model

Please use train.py to train the source model. For example:

cd code/
python train.py --dataset PACS \
                --data_dir your_data_dir \
                --opt_type Adam \
                --lr 5e-5 \
                --max_epoch 50 \
                --net resnet18 \
                --test_envs 0

Change --dataset PACS to use other datasets, such as office-home, VLCS, DomainNet, CIFAR-10, and CIFAR-100.
Set --net to use a different backbone, such as resnet50 or ViT-B16.
Set --test_envs to the index of the target domain (0 in the example above).
For CIFAR-10 and CIFAR-100, --data_dir and --test_envs do not need to be set.
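To train one source model per target domain, you can loop over --test_envs. The sketch below is hypothetical (train_command and train_all_targets are not part of this repository); it reuses only the flags from the training example above and assumes that --test_envs takes a domain index from 0 to N-1 (e.g., four domains for PACS) and that it is launched from the repository root, so code/ is used as the working directory.

```python
import subprocess


def train_command(dataset, data_dir, test_env, net="resnet18"):
    """Build the train.py argument list from the example above for one target domain."""
    return [
        "python", "train.py",
        "--dataset", dataset,
        "--data_dir", data_dir,
        "--opt_type", "Adam",
        "--lr", "5e-5",
        "--max_epoch", "50",
        "--net", net,
        "--test_envs", str(test_env),
    ]


def train_all_targets(dataset, data_dir, num_domains, net="resnet18", dry_run=True):
    """Build one command per target-domain index; run them only if dry_run is False.

    Assumption: domain indices run from 0 to num_domains - 1.
    """
    commands = [train_command(dataset, data_dir, env, net) for env in range(num_domains)]
    if not dry_run:
        for cmd in commands:
            # Equivalent to `cd code/` before running train.py.
            subprocess.run(cmd, check=True, cwd="code")
    return commands


if __name__ == "__main__":
    for cmd in train_all_targets("PACS", "your_data_dir", num_domains=4):
        print(" ".join(cmd))
```

Keeping dry_run=True by default lets you inspect the generated commands before starting several long training runs.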

Test time adaptation

For domain datasets such as PACS and OfficeHome, run the following command:

python unsupervise_adapt.py --dataset PACS \
                            --data_dir your_data_dir \
                            --adapt_alg G-FRET \
                            --pretrain_dir your_pretrain_model_dir \
                            --lr 1e-4 \
                            --net resnet18 \
                            --test_envs 0

For corrupted datasets such as CIFAR10-C and CIFAR100-C, run the following command:

python unsupervise_adapt_corrupted.py --dataset CIFAR-10 \
                                      --data_dir your_data_dir \
                                      --adapt_alg G-FRET \
                                      --pretrain_dir your_pretrain_model_dir \
                                      --lr 1e-4 \
                                      --net resnet18

Change --adapt_alg G-FRET to select a different test time adaptation method, e.g., S-FRET, TSD, BN, or Tent.
--pretrain_dir is the path to the source model, e.g., ./train_outputs/model.pkl.
G-FRET uses default hyperparameters set in our code. For better results, consider tuning --lam_GFRET, --filter_K, and --GFRET_K; see our paper for guidance on selecting them.
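To compare adaptation methods on the same source model, the adaptation commands can be generated in a loop. This is a minimal sketch under stated assumptions: adapt_command and run_all are hypothetical helpers (not part of this repository), the method names are the ones listed above, the scripts are assumed to live in code/, and the G-FRET hyperparameters (--lam_GFRET, --filter_K, --GFRET_K) are only passed through via extra, with values you choose following the paper.

```python
import subprocess

# Values accepted by --adapt_alg, as listed in this README.
METHODS = ["G-FRET", "S-FRET", "TSD", "BN", "Tent"]


def adapt_command(method, dataset, data_dir, pretrain_dir, test_env=0,
                  corrupted=False, net="resnet18", lr="1e-4", extra=None):
    """Build one adaptation command mirroring the examples above.

    Domain datasets (PACS, OfficeHome) use unsupervise_adapt.py with --test_envs;
    corrupted datasets (CIFAR10-C, CIFAR100-C) use unsupervise_adapt_corrupted.py.
    """
    script = "unsupervise_adapt_corrupted.py" if corrupted else "unsupervise_adapt.py"
    cmd = [
        "python", script,
        "--dataset", dataset,
        "--data_dir", data_dir,
        "--adapt_alg", method,
        "--pretrain_dir", pretrain_dir,
        "--lr", lr,
        "--net", net,
    ]
    if not corrupted:
        cmd += ["--test_envs", str(test_env)]
    # Optional extra flags, e.g. G-FRET hyperparameters; the values are up to you.
    for flag, value in (extra or {}).items():
        cmd += [flag, str(value)]
    return cmd


def run_all(commands, dry_run=True):
    """Print each command; execute it (from code/) only when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True, cwd="code")


if __name__ == "__main__":
    pretrain = "./train_outputs/model.pkl"  # example path from this README
    run_all([adapt_command(m, "PACS", "your_data_dir", pretrain) for m in METHODS])
```

Printing the commands first (a dry run) makes it easy to check paths and flags before launching a long sweep.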

Tested Environment

We tested our code in the environment described below.

OS: Ubuntu 18.04.6 LTS
GPU: NVIDIA GeForce RTX 4090
GPU Driver Version: 535.129.03
CUDA Version: 12.2
