Skip to content

Commit

Permalink
Merge pull request #179 from labmlai/slider
Browse files Browse the repository at this point in the history
smoothing fixes
  • Loading branch information
lakshith-403 committed Feb 21, 2024
2 parents 6a04503 + 79921aa commit affdde2
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 49 deletions.
2 changes: 1 addition & 1 deletion app/server/labml_app/analyses/experiments/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def defaults(cls):
base_experiment="",
step_range=[-1, -1],
is_base_distributed=False,
smooth_value=-1,
smooth_value=1,
)

def update_preferences(self, data: preferences.PreferencesData) -> None:
Expand Down
2 changes: 1 addition & 1 deletion app/server/labml_app/analyses/preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def defaults(cls):
errors=[],
step_range=[-1, -1],
focus_smoothed=True,
smooth_value=-1
smooth_value=1
)

def update_preferences(self, data: PreferencesData) -> None:
Expand Down
3 changes: 2 additions & 1 deletion app/ui/src/analyses/experiments/chart_wrapper/card.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ export class CardWrapper {
basePlotIdx: this.basePlotIdx,
width: this.width,
isDivergent: true,
onlySelected: true
onlySelected: true,
smoothValue: this.smoothValue
}).render($)
})
}
Expand Down
3 changes: 2 additions & 1 deletion app/ui/src/analyses/experiments/chart_wrapper/view.ts
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ export class ViewWrapper {
let changeHandler = new ChangeHandlers.ToggleChangeHandler(this, idx, true)
changeHandler.change()
},
isDivergent: true
isDivergent: true,
smoothValue: this.dataStore.smoothValue
})
this.sparkLines.render($)
})
Expand Down
30 changes: 23 additions & 7 deletions app/ui/src/components/charts/lines/chart.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import d3 from "../../../d3"
import {WeyaElement, WeyaElementFunction} from '../../../../../lib/weya/weya'
import {ChartOptions} from '../types'
import {SeriesModel} from "../../../models/run"
import {fillPlotPreferences, getExtent, getLogScale, getScale, trimSteps} from "../utils"
import {PointValue, SeriesModel} from "../../../models/run"
import {
getExtent,
getLogScale,
getScale,
getSmoothWindow,
smoothSeries,
trimSteps
} from "../utils"
import {LineFill, LinePlot} from "./plot"
import {BottomAxis, RightAxis} from "../axis"
import {formatStep} from "../../../utils/value"
Expand Down Expand Up @@ -50,7 +57,9 @@ export class LineChart {
private svgBoundingClientRect: DOMRect
private readonly uniqueItems: Map<string, number>
private readonly focusSmoothed: boolean
private readonly smoothValue: number

private readonly currentSmoothedSeries: PointValue[][]
private readonly baseSmoothedSeries: PointValue[][]

constructor(opt: LineChartOptions) {
this.currentSeries = opt.series
Expand All @@ -63,7 +72,6 @@ export class LineChart {
this.baseSeries = trimSteps(this.baseSeries, opt.stepRange[0], opt.stepRange[1])
this.currentSeries = trimSteps(this.currentSeries, opt.stepRange[0], opt.stepRange[1])
this.focusSmoothed = opt.focusSmoothed
this.smoothValue = opt.smoothValue

this.uniqueItems = new Map<string, number>()
this.axisSize = 30
Expand All @@ -74,10 +82,18 @@ export class LineChart {
this.chartHeight = Math.round(Math.min(this.chartWidth, windowHeight) / 2)

// TODO show something if everything is not selected
let smoothWindow = getSmoothWindow(this.currentSeries, this.baseSeries, opt.smoothValue)

this.filteredBaseSeries = this.baseSeries.filter((_, i) => this.basePlotIndex[i] == 1)
this.filteredCurrentSeries = this.currentSeries.filter((_, i) => this.currentPlotIndex[i] == 1)

this.currentSmoothedSeries = this.filteredCurrentSeries.map(s => {
return smoothSeries(s.series, smoothWindow)
})
this.baseSmoothedSeries = this.filteredBaseSeries.map(s => {
return smoothSeries(s.series, smoothWindow)
})

const stepExtent = getExtent(this.filteredBaseSeries.concat(this.filteredCurrentSeries).map(s => s.series), d => d.step, false, true)
this.xScale = getScale(stepExtent, this.chartWidth, false)

Expand Down Expand Up @@ -198,7 +214,7 @@ export class LineChart {
if (this.filteredCurrentSeries.length < 3 && this.filteredBaseSeries.length == 0) {
this.filteredCurrentSeries.map((s, i) => {
new LineFill({
series: s.series,
series: this.currentSmoothedSeries[i],
xScale: this.xScale,
yScale: this.yScale,
color: document.body.classList.contains("light")
Expand All @@ -219,7 +235,7 @@ export class LineChart {
: this.chartColors.getColor(this.uniqueItems.get(s.name)),
renderHorizontalLine: true,
smoothFocused: this.focusSmoothed,
smoothValue: this.smoothValue
smoothedSeries: this.currentSmoothedSeries[i],
})
this.linePlots.push(linePlot)
linePlot.render($)
Expand All @@ -236,7 +252,7 @@ export class LineChart {
isBase: true,
renderHorizontalLine: true,
smoothFocused: this.focusSmoothed,
smoothValue: this.smoothValue
smoothedSeries: this.baseSmoothedSeries[i],
})
this.linePlots.push(linePlot)
linePlot.render($)
Expand Down
5 changes: 2 additions & 3 deletions app/ui/src/components/charts/lines/plot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export interface LinePlotOptions extends PlotOptions {
isBase?: boolean
renderHorizontalLine?: boolean
smoothFocused?: boolean
smoothValue: number
smoothedSeries: PointValue[]
}

export class LinePlot {
Expand Down Expand Up @@ -42,8 +42,7 @@ export class LinePlot {
return d.step
}).left

let smoothWindow = mapRange(opt.smoothValue, 1, 100, 1, this.series.length/10)
this.smoothedSeries = smoothSeries(this.series, opt.smoothValue)
this.smoothedSeries = opt.smoothedSeries

this.smoothedLine = d3.line()
.curve(d3.curveMonotoneX)
Expand Down
31 changes: 20 additions & 11 deletions app/ui/src/components/charts/spark_lines/chart.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {WeyaElementFunction} from '../../../../../lib/weya/weya'
import {ChartOptions} from '../types'
import {SeriesModel} from "../../../models/run"
import {fillPlotPreferences, getExtent} from "../utils"
import {getExtent, getSmoothWindow, smoothSeries} from "../utils"
import {SparkLine} from "./spark_line"
import ChartColors from "../chart_colors"
import {DefaultLineGradient} from "../chart_gradients"
Expand All @@ -15,6 +15,7 @@ interface CompareSparkLinesOptions extends ChartOptions {
isMouseMoveOpt?: boolean
isDivergent?: boolean
onlySelected?: boolean
smoothValue: number
}

export class SparkLines {
Expand All @@ -24,8 +25,6 @@ export class SparkLines {
basePlotIdx: number[]
isEditable: boolean
rowWidth: number
minLastValue: number
maxLastValue: number
isMouseMoveOpt: boolean
stepExtent: [number, number]
onCurrentSelect?: (i: number) => void
Expand All @@ -36,6 +35,9 @@ export class SparkLines {
uniqueItems: Map<string, number>
onlySelected: boolean

private readonly currentLastValues: number[]
private readonly baseLastValues: number[]

constructor(opt: CompareSparkLinesOptions) {
this.currentSeries = opt.series
this.baseSeries = opt.baseSeries ?? []
Expand All @@ -60,8 +62,17 @@ export class SparkLines {
}
}

this.maxLastValue = Math.max(...lastValues)
this.minLastValue = Math.min(...lastValues)
this.currentLastValues = []
this.baseLastValues = []
let smoothWindow = getSmoothWindow(this.currentSeries, this.baseSeries, opt.smoothValue)
for (let i = 0; i < this.currentSeries.length; i++) {
let smoothedSeries = smoothSeries(this.currentSeries[i].series, smoothWindow)
this.currentLastValues.push(smoothedSeries[smoothedSeries.length - 1].value)
}
for (let i = 0; i < this.baseSeries.length; i++) {
let smoothedSeries = smoothSeries(this.baseSeries[i].series, smoothWindow)
this.baseLastValues.push(smoothedSeries[smoothedSeries.length - 1].value)
}

this.stepExtent = getExtent(this.currentSeries.concat(this.baseSeries).map(s => s.series), d => d.step)

Expand Down Expand Up @@ -94,12 +105,11 @@ export class SparkLines {
stepExtent: this.stepExtent,
width: this.rowWidth,
onClick: onClick,
minLastValue: this.minLastValue,
maxLastValue: this.maxLastValue,
color: document.body.classList.contains("light")
? this.chartColors.getSecondColor(this.uniqueItems.get(s.name))
: this.chartColors.getColor(this.uniqueItems.get(s.name)),
isMouseMoveOpt: this.isMouseMoveOpt
isMouseMoveOpt: this.isMouseMoveOpt,
lastSmoothedValue: this.currentLastValues[i]
})
this.sparkLines.push(sparkLine)
})
Expand All @@ -120,13 +130,12 @@ export class SparkLines {
stepExtent: this.stepExtent,
width: this.rowWidth,
onClick: onClick,
minLastValue: this.minLastValue,
maxLastValue: this.maxLastValue,
color: document.body.classList.contains("light")
? this.chartColors.getColor(this.uniqueItems.get(s.name))
: this.chartColors.getSecondColor(this.uniqueItems.get(s.name)),
isMouseMoveOpt: this.isMouseMoveOpt,
isBase: true
isBase: true,
lastSmoothedValue: this.baseLastValues[i]
})
this.sparkLines.push(sparkLine)
})
Expand Down
13 changes: 5 additions & 8 deletions app/ui/src/components/charts/spark_lines/spark_line.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@ export interface SparkLineOptions {
width: number
stepExtent: [number, number]
selected: number
minLastValue: number
maxLastValue: number
onClick?: () => void
isMouseMoveOpt?: boolean
color: string
isBase?: boolean
lastSmoothedValue: number
}

export class SparkLine {
series: PointValue[]
name: string
minLastValue: number
maxLastValue: number
color: string
selected: number
titleWidth: number
Expand All @@ -39,6 +36,7 @@ export class SparkLine {
bisect: d3.Bisector<number, number>
linePlot: LinePlot
isBase: boolean
lastSmoothedValue: number

constructor(opt: SparkLineOptions) {
this.series = opt.series
Expand All @@ -52,9 +50,8 @@ export class SparkLine {
this.color = this.selected >= 0 ? opt.color : getBaseColor()
this.chartWidth = Math.min(300, Math.round(opt.width * .60))
this.titleWidth = (opt.width - this.chartWidth) / 2
this.minLastValue = opt.minLastValue
this.maxLastValue = opt.maxLastValue
this.isBase = opt.isBase ?? false
this.lastSmoothedValue = opt.lastSmoothedValue

this.yScale = getScale(getExtent([this.series], d => d.value, true), -25)
this.xScale = getScale(opt.stepExtent, this.chartWidth)
Expand Down Expand Up @@ -90,7 +87,7 @@ export class SparkLine {
} else {
this.secondaryElem.textContent = ''
}
this.primaryElem.textContent = formatFixed(last.smoothed, 6)
this.primaryElem.textContent = formatFixed(this.lastSmoothedValue, 6)
}

render($: WeyaElementFunction) {
Expand All @@ -112,7 +109,7 @@ export class SparkLine {
yScale: this.yScale,
color: '#7f8c8d',
isBase: this.isBase,
smoothValue: 1
smoothedSeries: this.series
})
this.linePlot.render($)
})
Expand Down
50 changes: 34 additions & 16 deletions app/ui/src/components/charts/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,29 @@ export function toDate(time: number) {
export function smoothSeries(series: PointValue[], windowSize: number): PointValue[] {
let result: PointValue[] = []
windowSize = ~~windowSize
if (series.length < windowSize) {
let extraWindow = windowSize / 2
extraWindow = ~~extraWindow

if (series.length <= windowSize) {
return series
}
for (let i = windowSize; i < series.length; i++) {
let sumX = 0, sumY = 0
for (let j = i - windowSize; j < i; j++) {
sumY += series[j].smoothed
sumX += series[i].step
}
let avgX = sumX / windowSize
let avgY = sumY / windowSize
result.push(<PointValue>{step: avgX, value: avgY, smoothed: avgY})

let count = 0
let total = 0

for (let i = 0; i < series.length + extraWindow; i++) {
let j = i - extraWindow
if (i < series.length) {
total += series[i].smoothed
count++
}
if (j - extraWindow - 1 >= 0) {
total -= series[j - extraWindow - 1].smoothed
count--
}
if (j>=0) {
result.push({step: series[j].step, value: total / count, smoothed: total / count})
}
}
return result
}
Expand All @@ -130,12 +141,7 @@ export function fillPlotPreferences(series: SeriesModel[], currentPlotIdx: numbe

let plotIdx = []
for (let s of series) {
let name = s.name.split('.')
if (name[0] === 'loss') {
plotIdx.push(1)
} else {
plotIdx.push(-1)
}
plotIdx.push(-1)
}

return plotIdx
Expand Down Expand Up @@ -191,4 +197,16 @@ export function trimSteps(series: SeriesModel[], min: number, max: number) : Ser

return res
})
}

export function getSmoothWindow(currentSeries: SeriesModel[], baseSeries: SeriesModel[], smoothValue: number): number {
let maxSeriesLength = Math.max(Math.max(...baseSeries.map(s=>s.series.length)),
Math.max(...currentSeries.map(s=>s.series.length)))
let smoothWindow = mapRange(smoothValue, 1, 100, 1, maxSeriesLength/10)

if (smoothWindow<=0 || smoothWindow>=maxSeriesLength) {
smoothWindow = 1
}

return smoothWindow
}

0 comments on commit affdde2

Please sign in to comment.