In [None]:
from calflops import calculate_flops
from model import InterpretableResnet2, InterpretableViT, CBM, ViTConceptModel
from processing.utils import get_info_from_lattice
from argparse import Namespace

In [None]:
# Dataset-specific settings. Add a new entry here to support another dataset.
DATASET_CONFIGS = {
    "awa2": dict(
        data_path='././DATA/Animals_with_Attributes2',
        concept_file='././DATA/concepts/awa2_concepts.json',
        lattice_path='././DATA/lattices/awa2_context.pkl',
        num_classes=50,
        num_attrs=85,
        backbone='resnet18',
    ),
    "inet100": dict(
        data_path='././DATA/inet100',
        concept_file='././DATA/concepts/inet100_concepts.json',
        lattice_path='././DATA/lattices/inet100_context.pkl',
        num_classes=100,
        num_attrs=700,
        backbone='resnet50',
    ),
    "cifar100": dict(
        data_path='././DATA/cifar100',
        concept_file='././DATA/concepts/cifar100_concepts.json',
        lattice_path='././DATA/lattices/cifar100_context.pkl',
        num_classes=100,
        num_attrs=700,
        backbone='resnet50',
    ),
}


def apply_dataset_config(ns):
    """Copy the settings for ``ns.dataset`` from ``DATASET_CONFIGS`` onto ``ns``.

    Sets ``data_path``, ``concept_file``, ``lattice_path``, ``num_classes``,
    ``num_attrs`` and ``backbone`` on ``ns`` in place.

    Raises:
        ValueError: if ``ns.dataset`` has no entry in ``DATASET_CONFIGS``
            (previously an unknown name silently left these fields unset and
            later cells failed with an AttributeError).
    """
    if ns.dataset not in DATASET_CONFIGS:
        raise ValueError(
            f"Unknown dataset {ns.dataset!r}; expected one of {sorted(DATASET_CONFIGS)}"
        )
    vars(ns).update(DATASET_CONFIGS[ns.dataset])


args = Namespace()
args.dataset = "awa2"  # "awa2" / "cifar100" / "inet100"
args.lattice_levels = [2, 1]  # lattice levels passed to get_info_from_lattice
args.backbone_layer_ids = [3, 4]  # backbone layer ids passed to the interpretable models
args.pretrained_clfs = False
apply_dataset_config(args)

In [None]:
# Load the per-level intents and per-level fcs for the configured lattice
# levels; they are consumed as intent_list / fc_list by the lattice-based models.
(perlevel_intents,
 perlevel_fcs) = get_info_from_lattice(
    args.lattice_path,
    args.lattice_levels,
)

In [None]:
# Architecture to profile: one of the keys of MODEL_BUILDERS below.
# "CBM" is the model this cell built before the selector was introduced.
MODEL_KIND = "CBM"

# Zero-argument builders so that only the selected model is instantiated.
# NOTE(review): the ViT variants presumably need a ViT name in args.backbone;
# the dataset config sets a ResNet name — confirm before selecting them.
MODEL_BUILDERS = {
    "InterpretableResnet2": lambda: InterpretableResnet2(
        intent_list=perlevel_intents,
        fc_list=perlevel_fcs,
        backbone_layer_ids=args.backbone_layer_ids,
        num_classes=args.num_classes,
        backbone_name=args.backbone,
    ),
    "InterpretableViT": lambda: InterpretableViT(
        intent_list=perlevel_intents,
        fc_list=perlevel_fcs,
        backbone_layer_ids=args.backbone_layer_ids,
        num_classes=args.num_classes,
        model_name=args.backbone,
    ),
    "CBM": lambda: CBM(
        model_name=args.backbone,
        num_classes=args.num_classes,
        num_attrs=args.num_attrs,
    ),
    "ViTConceptModel": lambda: ViTConceptModel(
        model_name=args.backbone,
        num_classes=args.num_classes,
        num_concepts=args.num_attrs,
    ),
}

if MODEL_KIND not in MODEL_BUILDERS:
    raise ValueError(
        f"Unknown MODEL_KIND {MODEL_KIND!r}; expected one of {sorted(MODEL_BUILDERS)}"
    )
model = MODEL_BUILDERS[MODEL_KIND]()


In [None]:
# Profile the currently built `model` on a single 3x224x224 input.
batch_size = 1
input_shape = (batch_size, 3, 224, 224)
flops, macs, params = calculate_flops(
    model=model,
    input_shape=input_shape,
    output_as_string=True,
    output_precision=4,
)

# Label with the model's actual class: a hard-coded name is wrong as soon as
# a different architecture is selected above.
print(f"{type(model).__name__} FLOPs:{flops}   MACs:{macs}   Params:{params}\n")

In [None]:
# Switch to a ViT backbone AND rebuild `model`. Changing args.backbone alone
# does not touch the already-constructed model, so on a fresh Run All the next
# profiling cell would otherwise measure the previous model a second time.
args.backbone = 'vit_base_patch16_224'
# NOTE(review): ViTConceptModel is assumed to be the ViT counterpart of CBM
# (same model_name / num_classes / concept-count inputs); use InterpretableViT
# instead to profile the lattice-based ViT model — confirm intended class.
model = ViTConceptModel(
    model_name=args.backbone,
    num_classes=args.num_classes,
    num_concepts=args.num_attrs,
)

In [None]:
# Profile `model` again, after the backbone switch in the previous cell.
batch_size = 1
input_shape = (batch_size, 3, 224, 224)
flops, macs, params = calculate_flops(
    model=model,
    input_shape=input_shape,
    output_as_string=True,
    output_precision=4,
)

# Label with the model's actual class rather than a hard-coded name.
print(f"{type(model).__name__} FLOPs:{flops}   MACs:{macs}   Params:{params}\n")

In [None]:
from cem.models.cem import ConceptEmbeddingModel
from cem.train.utils import wrap_pretrained_model
from torchvision.models import resnet18, resnet50
from calflops import calculate_flops

In [None]:
# torchvision constructor for the CEM concept extractor, mirroring the
# per-dataset backbone chosen in the config cell (awa2 -> resnet18,
# inet100/cifar100 -> resnet50). Keyed by dataset because args.backbone may
# have been overridden to a ViT name by an earlier cell.
CEM_EXTRACTORS = {"awa2": resnet18, "inet100": resnet50, "cifar100": resnet50}

# Size the CEM heads from the dataset config so the numbers are comparable
# with the models profiled above; hard-coding 700/100 only matched
# inet100/cifar100, not the default awa2 (85 attributes, 50 classes).
cem_model = ConceptEmbeddingModel(
    n_concepts=args.num_attrs,  # Number of training-time concepts
    n_tasks=args.num_classes,  # Number of output labels
    emb_size=16,
    concept_loss_weight=0.1,
    learning_rate=1e-3,
    optimizer="adam",
    c_extractor_arch=wrap_pretrained_model(CEM_EXTRACTORS[args.dataset]),
    training_intervention_prob=0.25,  # RandInt probability
)

In [None]:
# Profile the Concept Embedding Model on a single 3x224x224 input.
batch_size = 1
input_shape = (batch_size, 3, 224, 224)
flops, macs, params = calculate_flops(
    model=cem_model,
    input_shape=input_shape,
    output_as_string=True,
    output_precision=4,
)

# The label previously read "FoCA-CBM" (copy-paste); report the real class.
print(f"{type(cem_model).__name__} FLOPs:{flops}   MACs:{macs}   Params:{params}\n")