In [46]:
%load_ext autoreload
%autoreload 2

import torch
from frcnn_resnet import *
import math
from torchinfo import summary

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
import os
import sys

# Local checkout that provides the `espcn` and `cyclegan` modules.
# Override with the QS_PYPI_LIB environment variable instead of editing the path.
# NOTE(review): the default is a machine-specific absolute path.
PYPI_LIB_DIR = os.environ.get(
    'QS_PYPI_LIB', '/Users/xinyuzhang/Documents/Evernote/quantzied_sparse/library/pypi')
if PYPI_LIB_DIR not in sys.path:  # keep the cell idempotent on re-run
    sys.path.append(PYPI_LIB_DIR)

from espcn import ESPCN
from cyclegan import define_G

In [23]:
# torchinfo statistics for the SR network on a (1, 1, 170, 170) input; espcn_mem() reads
# total_params / total_output from this. NOTE(review): ESPCN(3, 1) presumably means
# (upscale factor, channels) — confirm in espcn.py.
espcn_stats = summary(ESPCN(3, 1), input_size=(1, 1, 170, 170))

In [24]:
def espcn_mem(bits=32, prune_feat=False, prune_weight=False, stats=None):
    """Estimate the ESPCN memory footprint in Mbit (1e6 bits).

    The dense footprint is ``(total_params + total_output) * bits`` taken from a
    torchinfo summary. Pruning is modelled as removing 50% of two specific
    tensors, at the same ``bits`` width:

    * ``prune_feat`` removes half of a 64 x 170 x 170 feature map;
    * ``prune_weight`` removes half of a 64 x 32 x 3 x 3 weight tensor.

    NOTE(review): these sizes presumably correspond to the first conv's output
    and the second conv's weights of ESPCN(3, 1) at a 170x170 input — confirm
    against espcn.py if the architecture or input size changes.

    Args:
        bits: Bit width applied to every parameter and activation.
        prune_feat: Subtract half of the 64x170x170 feature map.
        prune_weight: Subtract half of the 64x32x3x3 weight tensor.
        stats: Object exposing ``total_params`` and ``total_output`` (e.g. a
            torchinfo summary). Defaults to the notebook-level ``espcn_stats``,
            so existing calls behave exactly as before.

    Returns:
        Estimated memory in Mbit, as a float.
    """
    if stats is None:
        stats = espcn_stats  # computed by the summary() cell above
    total = stats.total_params * bits + stats.total_output * bits

    pruned_feat_elems = 64 * 170 * 170    # one 64-channel 170x170 feature map
    pruned_weight_elems = 64 * 32 * 3 * 3  # one 64x32 3x3 conv weight tensor

    if prune_feat:
        total -= pruned_feat_elems * bits / 2  # 50% of that map is pruned away

    if prune_weight:
        total -= pruned_weight_elems * bits / 2  # 50% of that weight is pruned away
    return total / 1e6

In [25]:
# torchinfo statistics for both generators on a (1, 3, 256, 256) input; cyclegan_mem()
# and pix2pix_mem() read total_params / total_output from these.
# NOTE(review): define_G(3, 3, 64, ...) presumably is (input_nc, output_nc, ngf, netG).
cyclegan_stats = summary(define_G(3, 3, 64, 'resnet_9blocks'), input_size=(1, 3, 256, 256))
pix2pix_stats = summary(define_G(3, 3, 64, 'unet_256'), input_size=(1, 3, 256, 256))

initialize network with normal
initialize network with normal


In [26]:
def pix2pix_mem(bits=32, prune_feat=False, prune_weight=False, stats=None):
    """Estimate the Pix2Pix ('unet_256') generator memory footprint in Mbit (1e6 bits).

    Every parameter and activation counted by the torchinfo summary is stored at
    ``bits`` bits. Pruning is modelled as a uniform 50% reduction of all weights
    (``prune_weight``) and/or all activations (``prune_feat``). A fixed set of
    tensors is then added back with ``size * bits / 2``, which restores them to
    their dense size, i.e. those tensors are treated as never pruned.

    Args:
        bits: Bit width for all parameters and activations.
        prune_feat: Halve activations, except the kept-dense ones below.
        prune_weight: Halve weights, except the kept-dense ones below.
        stats: Object exposing ``total_params`` and ``total_output`` (e.g. a
            torchinfo summary). Defaults to the notebook-level ``pix2pix_stats``,
            so existing calls behave exactly as before.

    Returns:
        Estimated memory in Mbit, as a float.
    """
    if stats is None:
        stats = pix2pix_stats  # computed by the summary() cell above
    weight_scale = 0.5 if prune_weight else 1
    feat_scale = 0.5 if prune_feat else 1
    total = stats.total_params * bits * weight_scale + stats.total_output * bits * feat_scale

    # NOTE(review): the two 4*4*64*8 terms below are presumably meant for the
    # innermost U-Net block. If define_G follows the reference UnetGenerator,
    # that block's activation is 1x1 spatial at a 256x256 input and its conv
    # weights are 4x4x512x512, so these add-backs may be understated — confirm.
    if prune_feat:
        dense_feat_elems = (
            128 * 128 * 64    # presumably the first layer's output
            + 256 * 256 * 3   # presumably the last layer's output
            + 4 * 4 * 64 * 8
            + 4 * 4 * 64 * 8
        )
        total += dense_feat_elems * bits / 2

    if prune_weight:
        dense_weight_elems = (
            4 * 4 * 64 * 3     # presumably the first 4x4 conv (3 -> 64)
            + 4 * 4 * 128 * 3  # presumably the last 4x4 transposed conv (128 -> 3)
            + 4 * 4 * 64 * 8
            + 4 * 4 * 64 * 8
        )
        total += dense_weight_elems * bits / 2
    return total / 1e6

In [27]:
pix2pix_mem(bits=32)  # fp32, no pruning (Mbit)

2268.192864

In [28]:
pix2pix_mem(bits=32, prune_weight=True)  # fp32, weights pruned (Mbit)

1397.977136

In [29]:
pix2pix_stats.total_output  # activation count from torchinfo; pix2pix_mem() multiplies it by bits

16466944

In [30]:
def cyclegan_mem(bits=32, prune_feat=False, prune_weight=False, stats=None):
    """Estimate the CycleGAN ('resnet_9blocks') generator memory footprint in Mbit (1e6 bits).

    Every parameter and activation counted by the torchinfo summary is stored at
    ``bits`` bits. Pruning is modelled as a uniform 50% reduction of all weights
    (``prune_weight``) and/or all activations (``prune_feat``). A fixed set of
    tensors is then added back with ``size * bits / 2``, which restores them to
    their dense size, i.e. those tensors are treated as never pruned.

    Args:
        bits: Bit width for all parameters and activations.
        prune_feat: Halve activations, except the kept-dense ones below.
        prune_weight: Halve weights, except the kept-dense ones below.
        stats: Object exposing ``total_params`` and ``total_output`` (e.g. a
            torchinfo summary). Defaults to the notebook-level ``cyclegan_stats``,
            so existing calls behave exactly as before.

    Returns:
        Estimated memory in Mbit, as a float.
    """
    if stats is None:
        stats = cyclegan_stats  # computed by the summary() cell above
    weight_scale = 0.5 if prune_weight else 1
    feat_scale = 0.5 if prune_feat else 1
    total = stats.total_params * bits * weight_scale + stats.total_output * bits * feat_scale

    if prune_feat:
        dense_feat_elems = (
            256 * 256 * 64     # presumably the first layer's full-resolution output
            + 128 * 128 * 128  # presumably half-resolution sampling-stage activations
            + 128 * 128 * 64
            + 3 * 256 * 256    # presumably the last layer's output
        )
        total += dense_feat_elems * bits / 2

    if prune_weight:
        # Presumably the first (3 -> 64) and last (64 -> 3) 7x7 convs, hence "* 2".
        dense_weight_elems = 7 * 7 * 64 * 3 * 2
        total += dense_weight_elems * bits / 2
    return total / 1e6

In [31]:
# Full backbone and its head, each summarised by torchinfo on a (1, 3, 480, 720) input;
# frcnn_mem() combines the two summaries.
resnet101 = ResNet101()
resnet101_head = ResNet101Head()
resnet101_stats = summary(resnet101, input_size=(1, 3, 480, 720))
resnet101_head_stats = summary(resnet101_head, input_size=(1, 3, 480, 720))

In [32]:
def frcnn_mem(bits=32, prune_feat=False, prune_weight=False,
              stats=None, head_stats=None, head_bits=32):
    """Estimate the Faster R-CNN ResNet-101 memory footprint in Gbit (1e9 bits).

    The full ``ResNet101`` summary is stored at ``bits`` bits, optionally with
    50% of weights (``prune_weight``) and/or activations (``prune_feat``) pruned.
    The head's contribution (``head_stats``) is then removed and re-added at its
    dense size with ``head_bits`` bits, i.e. the head is neither quantized nor
    pruned. NOTE(review): this assumes ResNet101Head's layers are a subset of
    ResNet101's — confirm in frcnn_resnet.

    Unlike the other ``*_mem`` helpers, this returns Gbit, not Mbit.

    Args:
        bits: Bit width for the non-head part of the network.
        prune_feat: Halve activations outside the head.
        prune_weight: Halve weights outside the head.
        stats: Summary of the full network (``total_params``/``total_output``).
            Defaults to the notebook-level ``resnet101_stats``.
        head_stats: Summary of the head. Defaults to ``resnet101_head_stats``.
        head_bits: Bit width kept for the head (previously hard-coded to 32).

    Returns:
        Estimated memory in Gbit, as a float.
    """
    if stats is None:
        stats = resnet101_stats
    if head_stats is None:
        head_stats = resnet101_head_stats

    def _footprint(summary_stats, width, halve_weights, halve_feats):
        # Bits for all params + activations of one summary, with optional 50% pruning.
        return (summary_stats.total_params * width * (0.5 if halve_weights else 1)
                + summary_stats.total_output * width * (0.5 if halve_feats else 1))

    total = _footprint(stats, bits, prune_weight, prune_feat)
    origin_head_total = _footprint(head_stats, head_bits, False, False)
    head_total = _footprint(head_stats, bits, prune_weight, prune_feat)
    return (total - head_total + origin_head_total) / 1e9

In [20]:
frcnn_mem(bits=32, prune_feat=False)  # fp32 dense baseline (Gbit)

221.83597056

In [21]:
frcnn_mem(bits=8, prune_feat=True, prune_weight=True)  # 8-bit, activations + weights pruned (Gbit)

84.712589312

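# Performance Density for CycleGAN

Density here is computed as `(1 / metric) / cyclegan_mem(...) * 1e6`, with memory in Mbit. The metric is inverted, presumably because lower is better for it (e.g. FID). **TODO:** name the metric and record where each value came from.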

In [67]:
# NOTE(review): the recorded output (~3.28e-06) is ~1e6x smaller than what the current
# cyclegan_mem() implies given the 8-bit CycleGAN outputs below; it was likely produced with
# a different/stale definition — re-run on a fresh kernel.
1 / 67.1 / cyclegan_mem(bits=32) * 1e6

3.2843623535690964e-06

In [36]:
1 / 81.8 / cyclegan_mem(bits=8, prune_weight=True) * 1e6  # 8-bit, weights pruned

11.226436157732458

In [37]:
1 / 83.72 / cyclegan_mem(bits=8, prune_weight=True) * 1e6  # 8-bit, weights pruned (second run)

10.968973694487756

In [38]:
1 / 100.4 / cyclegan_mem(bits=8, prune_weight=True, prune_feat=True) * 1e6  # 8-bit, weights + activations pruned

16.671896630414196

In [39]:
1 / 89.35 / cyclegan_mem(bits=8, prune_weight=True, prune_feat=True) * 1e6  # 8-bit, weights + activations pruned (second run)

18.733726040219203

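# Performance Density for Pix2Pix

Density here is computed as `metric / pix2pix_mem(...)`, with memory in Mbit. **TODO:** name the metric and record where each value came from.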

In [62]:
119.9 / pix2pix_mem(bits=32)  # fp32 dense baseline

0.05286146601685102

In [63]:
127.1 / pix2pix_mem(bits=8, prune_weight=True)  # 8-bit, weights pruned

0.36366832254114917

In [64]:
123.5 / pix2pix_mem(bits=8, prune_weight=True)  # 8-bit, weights pruned (second run)

0.35336772489246204

In [65]:
 154.8 / pix2pix_mem(bits=8, prune_weight=True, prune_feat=True)  # 8-bit, weights + activations pruned

0.5362473039543251

In [66]:
135.0 / pix2pix_mem(bits=8, prune_weight=True, prune_feat=True)  # 8-bit, weights + activations pruned (second run)

0.4676575325183067

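# Performance Density for SR

Density here is computed as `metric / ln(espcn_mem(...))`. Unlike the other sections, it divides by the natural log of the memory (Mbit). **TODO:** name the metric (presumably PSNR in dB) and record where each value came from.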

In [47]:
32.84 / math.log(espcn_mem(bits=32))  # fp32 dense baseline; natural log of Mbit

6.228346051769872

In [48]:
32.02 / math.log(espcn_mem(bits=32, prune_feat=True))  # fp32, activations pruned

6.268580591959094

In [49]:
32.03 / math.log(espcn_mem(bits=8, prune_feat=True))  # 8-bit, activations pruned

8.606235711190584

In [50]:
31.03 / math.log(espcn_mem(bits=8, prune_feat=True, prune_weight=True))  # 8-bit, activations + weights pruned

8.341544037556975

In [51]:
31.66 / math.log(espcn_mem(bits=8, prune_feat=True, prune_weight=True))  # 8-bit, activations + weights pruned (second run)

8.510901844313691

In [52]:
32.51 / math.log(espcn_mem(bits=8, prune_weight=True))  # 8-bit, weights pruned

8.368385063163224

In [53]:
32.54 / math.log(espcn_mem(bits=8, prune_weight=True))  # 8-bit, weights pruned (second run)

8.37610735021013

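# Performance Density for Detection

Density here is computed as `metric / frcnn_mem(...)`, with memory in **Gbit**: `frcnn_mem` divides by 1e9, whereas the other helpers divide by 1e6. **TODO:** name the metric (presumably mAP) and record where each value came from.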

In [54]:
 74.47 / frcnn_mem(bits=32)  # fp32 dense baseline (Gbit)

0.3356984884462553

In [56]:
 73.04 / frcnn_mem(bits=32, prune_feat=True)  # fp32, activations pruned

0.5066727997494486

In [57]:
 73 / frcnn_mem(bits=8, prune_feat=True)  # 8-bit, activations pruned

0.8600205913038409

In [58]:
 73 / frcnn_mem(bits=8, prune_feat=True, prune_weight=True)  # 8-bit, activations + weights pruned

0.8617373237304546

In [59]:
70.13 / frcnn_mem(bits=8, prune_feat=True, prune_weight=True)  # 8-bit, activations + weights pruned (second run)

0.8278580618248874

In [60]:
74.44 / frcnn_mem(bits=8, prune_weight=True)  # 8-bit, weights pruned

0.714858168009346

In [61]:
74.13 / frcnn_mem(bits=8, prune_weight=True)  # 8-bit, weights pruned (second run)

0.7118811928335951