Skip to content

Commit

Permalink
For real?
Browse files Browse the repository at this point in the history
  • Loading branch information
jayanthkoushik committed Jul 25, 2023
1 parent 1065dc9 commit c21e85f
Show file tree
Hide file tree
Showing 41 changed files with 4,793 additions and 12,610 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ repos:
language: system
files: "paper/.*"
pass_filenames: false
entry: sh -c "(cd paper && pipx run --spec shiny-mdc shinymdc -i main.md -o main.pdf -t stylish -m smalltabs=true)"
entry: sh -c "(cd paper && pipx run --spec git+https://github.com/jayanthkoushik/shiny-mdc shinymdc -i main.md -o main.pdf -t stylish -m smalltabs=true,nonidan=true)"
16 changes: 14 additions & 2 deletions alphanet/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
patheffects as pe,
pyplot as plt,
)
from matplotlib.container import BarContainer
from matplotlib.ticker import (
FixedFormatter,
FixedLocator,
Expand Down Expand Up @@ -817,6 +818,7 @@ class PlotSplitAccVsExp(_BaseMultiExpPlotCmd, BasePlotCmd):
xlabel: str = ""
legend_loc: str = "upper right"
legend_bbox_to_anchor: Optional[Tuple[float, float]] = None
show_titles: bool = True

@staticmethod
def _get_metric_desc(metric: Literal["euclidean", "cosine", "random"]) -> str:
Expand Down Expand Up @@ -891,8 +893,10 @@ def __call__(self):
fig=_fig, ax=_ax, left=True, right=True, bottom=False, top=True
)

if self.col is not None and len(cols) > 1:
if self.col is not None and len(cols) > 1 and self.show_titles:
_ax.set_title(_col)
else:
_ax.set_title("")

_legend = _ax.legend()
if _i == 0:
Expand Down Expand Up @@ -1872,8 +1876,16 @@ def __call__(self):
g.set_titles("")
for facet_name, ax in g.axes_dict.items():
ax_df = df[df["Dataset"] == facet_name]
ax.yaxis.set_major_formatter(PercentFormatter(xmax=len(ax_df)))
fmtr = PercentFormatter(xmax=len(ax_df))
ax.yaxis.set_major_formatter(fmtr)
ax.set_yticks([int(len(ax_df) * _p / 100) for _p in [15, 30, 45]])
bar_containers = [
_con for _con in ax.containers if isinstance(_con, BarContainer)
]
assert len(bar_containers) == 1
bar_container = bar_containers[0]
# pylint: disable=cell-var-from-loop,unnecessary-lambda
ax.bar_label(bar_container, fmt=lambda _n: fmtr(_n), label_type="center")
g.set_xlabels("")
g.set_ylabels("")
g.despine(top=True, left=True, right=True, bottom=False, trim=False)
Expand Down
16 changes: 16 additions & 0 deletions alphanet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,22 @@ def __call__(self) -> TrainResult:
self.training.pred_scale,
).to(DEFAULT_DEVICE)

# # Initialize weights.
# _hact = self.alphanet.hact
# if _hact == "relu" or _hact == "leaky_relu":
# _a = 0.01 if _hact == "leaky_relu" else 0

# def _w_init_fn(_w):
# torch.nn.init.kaiming_normal_(_w, a=_a, nonlinearity=_hact)

# else:
# _w_init_fn = torch.nn.init.xavier_normal_
# _b_init_fn = torch.nn.init.zeros_

# for _param in self.alphanet.linear_layer__seq + [self.alphanet.conv_layer]:
# _w_init_fn(_param.weight)
# _b_init_fn(_param.bias)

self.training.ptopt.set_weights(alphanet_classifier.parameters())
loss_fn = torch.nn.CrossEntropyLoss()
logging.info("setting up model...done")
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit c21e85f

Please sign in to comment.